Implement guard collectives (optimized version) (#156562)

This is a remix of https://github.com/pytorch/pytorch/pull/155558

Instead of mediating guard collective via a config option, in this one it's done via a `set_stance` like API. The motivation is that checking for the config value on entry on torch.compile is apparently quite expensive, according to functorch_maml_omniglot. So this makes it a bit cheaper.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156562
Approved by: https://github.com/Microve
This commit is contained in:
Edward Z. Yang
2025-06-22 13:51:41 -07:00
committed by PyTorch MergeBot
parent 73772919d2
commit 17eb649d55
10 changed files with 170 additions and 4 deletions

View File

@ -21,6 +21,7 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
list_backends
disable
set_stance
set_enable_guard_collectives
cudagraph_mark_step_begin
is_compiling
is_dynamo_compiling

View File

@ -25,6 +25,7 @@ from torch._dynamo.comptime import comptime
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import same
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.compiler import set_enable_guard_collectives
from torch.distributed._functional_collectives import _maybe_wrap_tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
@ -61,6 +62,15 @@ def init_weights(m):
m.bias.data.fill_(0.01)
@contextmanager
def enable_guard_collectives():
old = set_enable_guard_collectives(True)
try:
yield
finally:
set_enable_guard_collectives(old)
class ToyModel(nn.Module):
def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
super().__init__()
@ -1141,6 +1151,31 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@enable_guard_collectives()
def test_guard_collective(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(x):
return x.sum()
x = torch.randn(10, device=self.rank)
f(x)
if self.rank == 0:
x = torch.randn(10, device=self.rank)
else:
x = torch.randn(12, device=self.rank) # recompile on one rank
f(x)
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_get_pg_attr(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):

View File

@ -1,8 +1,13 @@
import enum
import types
from typing import overload
from typing import Optional, overload
from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn
from torch._dynamo.types import (
DynamoCallback,
DynamoGuardCompleteHook,
DynamoGuardHook,
GuardFn,
)
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
@ -13,6 +18,9 @@ def set_code_exec_strategy(
code: types.CodeType, strategy: _FrameExecStrategy
) -> None: ...
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
def set_guard_complete_hook(
hook: Optional[DynamoGuardCompleteHook],
) -> Optional[DynamoGuardCompleteHook]: ...
def raise_sigtrap() -> None: ...
class _CacheEntry:

View File

@ -44,6 +44,7 @@ if TYPE_CHECKING:
from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,
set_eval_frame,
set_guard_complete_hook,
set_guard_error_hook,
unsupported,
)

View File

@ -22,6 +22,7 @@ from . import config
_COMPILE_PG: Optional[dist.ProcessGroup] = None
_GUARD_PG: Optional[dist.ProcessGroup] = None
def get_compile_pg() -> Optional[dist.ProcessGroup]:
@ -39,3 +40,15 @@ def get_compile_pg() -> Optional[dist.ProcessGroup]:
return _COMPILE_PG
return None
# NB: Unlike get_compile_pg, this is only called when guard collectives were
# explicitly requested
def get_guard_pg() -> Optional[dist.ProcessGroup]:
if dist.is_available() and dist.is_initialized():
global _GUARD_PG
if _GUARD_PG is None:
_GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg")
return _GUARD_PG
return None

View File

@ -58,6 +58,7 @@ from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,
set_code_exec_strategy,
set_eval_frame,
set_guard_complete_hook,
set_guard_error_hook,
set_skip_guard_eval_unsafe,
unsupported,
@ -90,7 +91,7 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from . import config, convert_frame, external_utils, trace_rules, utils
from . import config, convert_frame, distributed, external_utils, trace_rules, utils
from .backends.registry import CompilerFn, lookup_backend
from .code_context import code_context
from .exc import (
@ -519,6 +520,38 @@ def _log_traced_frames():
log.info(msg)
def guard_collectives_hook(guard_eval_result):
import torch.distributed as dist
from torch._dynamo.utils import dynamo_timed
# guard_eval_result == True ==> cache hit
if pg := distributed.get_guard_pg():
with dynamo_timed(
"guard_collective", log_pt2_compile_event=True, log_waitcounter=True
):
log.info("guard_collective %s", guard_eval_result)
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "guard_collective",
"encoding": "string",
},
payload_fn=lambda: str(guard_eval_result),
)
# TODO: a bit awkward to time, this isn't inside of the dynamo compile region
all_results = [None] * pg.size()
dist.all_gather_object(all_results, guard_eval_result, group=pg)
# True = everyone hit, OK to run
# False = someone missed, force recompile everywhere
res = all(all_results)
log.info("guard_collective %s -> %s", guard_eval_result, res)
return res
return guard_eval_result
_not_set = object()
class _TorchDynamoContext:
def __init__(
self,

View File

@ -114,6 +114,13 @@ class DynamoGuardHook(Protocol):
) -> None: ...
class DynamoGuardCompleteHook(Protocol):
def __call__(
self,
cache_hit: bool,
) -> bool: ...
class ProfilerStartHook(Protocol):
def __call__(
self,

View File

@ -21,6 +21,7 @@ __all__ = [
"list_backends",
"disable",
"set_stance",
"set_enable_guard_collectives",
"cudagraph_mark_step_begin",
"wrap_numpy",
"is_compiling",
@ -330,6 +331,35 @@ def set_stance(
set_stance._dynamo_forbidden = True # type: ignore[attr-defined]
def set_enable_guard_collectives(enabled: bool):
"""
Enables use of collectives *during* guard evaluation to synchronize behavior
across ranks. This is expensive: we have to issue a collective every time
we enter a compiled code region, even if no rank actually would need to
compile. This can help prevent NCCL hangs by ensuring that we never have a
situation where one rank starts recompiling while other ranks don't compile;
it is especially useful in conjunction with enable_compiler_collectives
where such a situation would immediately cause a hang (as it is necessary
for all ranks to compile at the same time to run compiler collectives). Like
compiler collectives, you can only run this on SPMD programs; you will hang
otherwise. Note that a guard collective is only issued if there is any
compiled code to guard on; if this the first time we encounter a frame or
the frame is skipped, we don't issue collectives.
Returns the previous setting of enabled.
"""
from torch._C._dynamo.eval_frame import set_guard_complete_hook # noqa: F401
from torch._dynamo.eval_frame import guard_collectives_hook
if enabled:
return set_guard_complete_hook(guard_collectives_hook) is not None
else:
return set_guard_complete_hook(None) is not None
set_enable_guard_collectives._dynamo_forbidden = True # type: ignore[attr-defined]
def cudagraph_mark_step_begin():
"""
Indicates that a new iteration of inference or training is about to begin.

View File

@ -11,6 +11,7 @@
#include <torch/csrc/utils/python_compat.h>
PyObject* guard_error_hook = NULL;
PyObject* guard_complete_hook = NULL;
typedef struct {
int active_dynamo_threads;
@ -626,6 +627,22 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
Py_RETURN_NONE;
}
static PyObject* set_guard_complete_hook(PyObject* dummy, PyObject* obj) {
PyObject* old_hook = guard_complete_hook;
if (obj == Py_None) {
obj = NULL;
}
guard_complete_hook = Py_XNewRef(obj);
if (old_hook == NULL) {
Py_RETURN_NONE;
} else {
return old_hook;
}
}
// Debugging function for GNU C only.
// Used to set gdb breakpoints in hot CPython sites from Python.
// Code example:
@ -666,6 +683,7 @@ static PyMethodDef _methods[] = {
{"unsupported", unsupported, METH_VARARGS, NULL},
{"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL},
{"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
{"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL},
{"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL},
{NULL, NULL, 0, NULL}};

View File

@ -7,6 +7,10 @@
#include <torch/csrc/dynamo/framelocals_mapping.h>
#include <torch/csrc/utils/python_compat.h>
extern "C" {
extern PyObject* guard_complete_hook;
}
static constexpr const char* cache_lookup_profiler_str =
"TorchDynamo Cache Lookup";
@ -197,7 +201,23 @@ PyObject* dynamo__custom_eval_frame(
// guard eval failed, keep propagating
fail();
return eval_result;
} else if (maybe_cached_code != Py_None) {
}
// NB: We only do guard collectives when there are any compiled code entries
// at all; these reduces overtriggering and we don't need to do guard
// collectives the very first time we've seen a frame
// TODO: We could also check if we had just created extra for the first
// time? Not too sure the best condition for extra->cache_entry_list
if (guard_complete_hook != nullptr && !extra->cache_entry_list.empty()) {
py::handle guard_complete_hook_handle(guard_complete_hook);
// False means force compilation (someone cache missed)
py::object res = guard_complete_hook_handle(maybe_cached_code != Py_None);
if (!py::cast<bool>(res)) {
maybe_cached_code = Py_None; // NB: non-owning
}
}
if (maybe_cached_code != Py_None) {
cached_code = (PyCodeObject*)maybe_cached_code;
// used cached version
DEBUG_TRACE("cache hit %s", get_frame_name(frame));