From 17eb649d5596c52bae65a069e03b4550155ad57f Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Sun, 22 Jun 2025 13:51:41 -0700
Subject: [PATCH] Implement guard collectives (optimized version) (#156562)

This is a remix of https://github.com/pytorch/pytorch/pull/155558

Instead of mediating guard collectives via a config option, this version
exposes them through a `set_stance`-like API. The motivation is that checking
the config value on every entry to torch.compile is apparently quite
expensive, according to the functorch_maml_omniglot benchmark, so this
approach makes it a bit cheaper.

Signed-off-by: Edward Z. Yang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156562
Approved by: https://github.com/Microve
---
 docs/source/torch.compiler_api.md           |  1 +
 test/distributed/test_dynamo_distributed.py | 35 +++++++++++++++++++++
 torch/_C/_dynamo/eval_frame.pyi             | 12 +++++--
 torch/_dynamo/decorators.py                 |  1 +
 torch/_dynamo/distributed.py                | 13 ++++++++
 torch/_dynamo/eval_frame.py                 | 35 ++++++++++++++++++++-
 torch/_dynamo/types.py                      |  7 +++++
 torch/compiler/__init__.py                  | 30 ++++++++++++++++++
 torch/csrc/dynamo/eval_frame.c              | 18 +++++++++++
 torch/csrc/dynamo/eval_frame_cpp.cpp        | 22 ++++++++++++-
 10 files changed, 170 insertions(+), 4 deletions(-)

diff --git a/docs/source/torch.compiler_api.md b/docs/source/torch.compiler_api.md
index 15156cab43dd..2b79b0e67007 100644
--- a/docs/source/torch.compiler_api.md
+++ b/docs/source/torch.compiler_api.md
@@ -21,6 +21,7 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
     list_backends
     disable
     set_stance
+    set_enable_guard_collectives
     cudagraph_mark_step_begin
     is_compiling
     is_dynamo_compiling
diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index 8446282c84ff..73ac6eb0da7b 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -25,6 +25,7 @@ from torch._dynamo.comptime import comptime
 from torch._dynamo.testing import collect_results
 from torch._dynamo.utils import same
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
+from torch.compiler import set_enable_guard_collectives
 from torch.distributed._functional_collectives import _maybe_wrap_tensor
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import (
@@ -61,6 +62,15 @@ def init_weights(m):
     m.bias.data.fill_(0.01)
 
 
+@contextmanager
+def enable_guard_collectives():
+    old = set_enable_guard_collectives(True)
+    try:
+        yield
+    finally:
+        set_enable_guard_collectives(old)
+
+
 class ToyModel(nn.Module):
     def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
         super().__init__()
@@ -1141,6 +1151,31 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
             for r in res[1:]:
                 self.assertEqual(res[0], r)
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @enable_guard_collectives()
+    def test_guard_collective(self):
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+            torch._dynamo.utils.clear_compilation_metrics()
+
+            @torch.compile()
+            def f(x):
+                return x.sum()
+
+            x = torch.randn(10, device=self.rank)
+            f(x)
+
+            if self.rank == 0:
+                x = torch.randn(10, device=self.rank)
+            else:
+                x = torch.randn(12, device=self.rank)  # recompile on one rank
+            f(x)
+
+            metrics = torch._dynamo.utils.get_compilation_metrics()
+            res = [None] * self.world_size
+            torch.distributed.all_gather_object(res, len(metrics))
+            for r in res[1:]:
+                self.assertEqual(res[0], r)
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_get_pg_attr(self):
         with _dynamo_dist_per_rank_init(self.rank, self.world_size):
diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi
index c89de9a1ff9f..05dde69b0470 100644
--- a/torch/_C/_dynamo/eval_frame.pyi
+++ b/torch/_C/_dynamo/eval_frame.pyi
@@ -1,8 +1,13 @@
 import enum
 import types
-from typing import overload
+from typing import Optional, overload
 
-from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn
+from torch._dynamo.types import (
+    DynamoCallback,
+    DynamoGuardCompleteHook,
+    DynamoGuardHook,
+    GuardFn,
+)
 
 def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
 def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
@@ -13,6 +18,9 @@ def set_code_exec_strategy(
     code: types.CodeType, strategy: _FrameExecStrategy
 ) -> None: ...
 def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
+def set_guard_complete_hook(
+    hook: Optional[DynamoGuardCompleteHook],
+) -> Optional[DynamoGuardCompleteHook]: ...
 def raise_sigtrap() -> None: ...
 
 class _CacheEntry:
diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py
index 0ab8f52311ae..9708eda2e039 100644
--- a/torch/_dynamo/decorators.py
+++ b/torch/_dynamo/decorators.py
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
     from torch._C._dynamo.eval_frame import (  # noqa: F401
         reset_code,
         set_eval_frame,
+        set_guard_complete_hook,
         set_guard_error_hook,
         unsupported,
     )
diff --git a/torch/_dynamo/distributed.py b/torch/_dynamo/distributed.py
index aa60b325844b..490b6330fafa 100644
--- a/torch/_dynamo/distributed.py
+++ b/torch/_dynamo/distributed.py
@@ -22,6 +22,7 @@ from . import config
 
 
 _COMPILE_PG: Optional[dist.ProcessGroup] = None
+_GUARD_PG: Optional[dist.ProcessGroup] = None
 
 
 def get_compile_pg() -> Optional[dist.ProcessGroup]:
@@ -39,3 +40,15 @@ def get_compile_pg() -> Optional[dist.ProcessGroup]:
         return _COMPILE_PG
 
     return None
+
+
+# NB: Unlike get_compile_pg, this is only called when guard collectives were
+# explicitly requested
+def get_guard_pg() -> Optional[dist.ProcessGroup]:
+    if dist.is_available() and dist.is_initialized():
+        global _GUARD_PG
+        if _GUARD_PG is None:
+            _GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg")
+        return _GUARD_PG
+
+    return None
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 49e8de7464c9..06b3facc5b9e 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -58,6 +58,7 @@ from torch._C._dynamo.eval_frame import (  # noqa: F401
     reset_code,
     set_code_exec_strategy,
     set_eval_frame,
+    set_guard_complete_hook,
     set_guard_error_hook,
     set_skip_guard_eval_unsafe,
     unsupported,
@@ -90,7 +91,7 @@ from torch.fx.experimental.symbolic_shapes import (
 )
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 
-from . import config, convert_frame, external_utils, trace_rules, utils
+from . import config, convert_frame, distributed, external_utils, trace_rules, utils
 from .backends.registry import CompilerFn, lookup_backend
 from .code_context import code_context
 from .exc import (
@@ -519,6 +520,38 @@ def _log_traced_frames():
     log.info(msg)
 
 
+def guard_collectives_hook(guard_eval_result):
+    import torch.distributed as dist
+    from torch._dynamo.utils import dynamo_timed
+
+    # guard_eval_result == True ==> cache hit
+    if pg := distributed.get_guard_pg():
+        with dynamo_timed(
+            "guard_collective", log_pt2_compile_event=True, log_waitcounter=True
+        ):
+            log.info("guard_collective %s", guard_eval_result)
+            torch._logging.trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "guard_collective",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: str(guard_eval_result),
+            )
+            # TODO: a bit awkward to time; this isn't inside the dynamo compile region
+            all_results = [None] * pg.size()
+            dist.all_gather_object(all_results, guard_eval_result, group=pg)
+            # True = everyone hit, OK to run
+            # False = someone missed, force recompile everywhere
+            res = all(all_results)
+            log.info("guard_collective %s -> %s", guard_eval_result, res)
+            return res
+    return guard_eval_result
+
+
+_not_set = object()
+
+
 class _TorchDynamoContext:
     def __init__(
         self,
diff --git a/torch/_dynamo/types.py b/torch/_dynamo/types.py
index 769df78937b3..fc9bc601fd63 100644
--- a/torch/_dynamo/types.py
+++ b/torch/_dynamo/types.py
@@ -114,6 +114,13 @@ class DynamoGuardHook(Protocol):
     ) -> None: ...
 
 
+class DynamoGuardCompleteHook(Protocol):
+    def __call__(
+        self,
+        cache_hit: bool,
+    ) -> bool: ...
+
+
 class ProfilerStartHook(Protocol):
     def __call__(
         self,
diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py
index de3267502f28..ff232da14448 100644
--- a/torch/compiler/__init__.py
+++ b/torch/compiler/__init__.py
@@ -21,6 +21,7 @@ __all__ = [
     "list_backends",
     "disable",
     "set_stance",
+    "set_enable_guard_collectives",
     "cudagraph_mark_step_begin",
     "wrap_numpy",
     "is_compiling",
@@ -330,6 +331,35 @@ def set_stance(
 set_stance._dynamo_forbidden = True  # type: ignore[attr-defined]
 
 
+def set_enable_guard_collectives(enabled: bool):
+    """
+    Enables use of collectives *during* guard evaluation to synchronize behavior
+    across ranks. This is expensive: we have to issue a collective every time
+    we enter a compiled code region, even if no rank actually needs to compile.
+    This can help prevent NCCL hangs by ensuring that we never have a situation
+    where one rank starts recompiling while the other ranks do not; it is
+    especially useful in conjunction with enable_compiler_collectives, where
+    such a situation would immediately cause a hang (as all ranks must compile
+    at the same time to run compiler collectives). Like compiler collectives,
+    this may only be run on SPMD programs; otherwise you will hang. Note that a
+    guard collective is only issued if there is any compiled code to guard on;
+    if this is the first time we encounter a frame or the frame is skipped, we
+    do not issue collectives.
+
+    Returns the previous setting of enabled.
+    """
+    from torch._C._dynamo.eval_frame import set_guard_complete_hook  # noqa: F401
+    from torch._dynamo.eval_frame import guard_collectives_hook
+
+    if enabled:
+        return set_guard_complete_hook(guard_collectives_hook) is not None
+    else:
+        return set_guard_complete_hook(None) is not None
+
+
+set_enable_guard_collectives._dynamo_forbidden = True  # type: ignore[attr-defined]
+
+
 def cudagraph_mark_step_begin():
     """
     Indicates that a new iteration of inference or training is about to begin.
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index cb6615c8aca6..f413782b2d30 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -11,6 +11,7 @@
 #include
 
 PyObject* guard_error_hook = NULL;
+PyObject* guard_complete_hook = NULL;
 
 typedef struct {
   int active_dynamo_threads;
@@ -626,6 +627,22 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
   Py_RETURN_NONE;
 }
 
+static PyObject* set_guard_complete_hook(PyObject* dummy, PyObject* obj) {
+  PyObject* old_hook = guard_complete_hook;
+
+  if (obj == Py_None) {
+    obj = NULL;
+  }
+
+  guard_complete_hook = Py_XNewRef(obj);
+
+  if (old_hook == NULL) {
+    Py_RETURN_NONE;
+  } else {
+    return old_hook;
+  }
+}
+
 // Debugging function for GNU C only.
 // Used to set gdb breakpoints in hot CPython sites from Python.
 // Code example:
@@ -666,6 +683,7 @@ static PyMethodDef _methods[] = {
     {"unsupported", unsupported, METH_VARARGS, NULL},
     {"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL},
     {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
+    {"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL},
    {"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL},
     {NULL, NULL, 0, NULL}};
 
diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp
index 9ec54b46b97a..e05de24259e0 100644
--- a/torch/csrc/dynamo/eval_frame_cpp.cpp
+++ b/torch/csrc/dynamo/eval_frame_cpp.cpp
@@ -7,6 +7,10 @@
 #include
 #include
 
+extern "C" {
+extern PyObject* guard_complete_hook;
+}
+
 static constexpr const char* cache_lookup_profiler_str =
     "TorchDynamo Cache Lookup";
 
@@ -197,7 +201,23 @@ PyObject* dynamo__custom_eval_frame(
     // guard eval failed, keep propagating
     fail();
     return eval_result;
-  } else if (maybe_cached_code != Py_None) {
+  }
+
+  // NB: We only do guard collectives when there are any compiled code entries
+  // at all; this reduces overtriggering, and we don't need to do guard
+  // collectives the very first time we've seen a frame.
+  // TODO: We could also check if we had just created extra for the first
+  // time? Not too sure of the best condition for extra->cache_entry_list.
+  if (guard_complete_hook != nullptr && !extra->cache_entry_list.empty()) {
+    py::handle guard_complete_hook_handle(guard_complete_hook);
+    // False means force compilation (someone cache missed)
+    py::object res = guard_complete_hook_handle(maybe_cached_code != Py_None);
+    if (!py::cast<bool>(res)) {
+      maybe_cached_code = Py_None; // NB: non-owning
+    }
+  }
+
+  if (maybe_cached_code != Py_None) {
     cached_code = (PyCodeObject*)maybe_cached_code;
     // used cached version
     DEBUG_TRACE("cache hit %s", get_frame_name(frame));