Implement guard collectives (optimized version) (#156562)
This is a remix of https://github.com/pytorch/pytorch/pull/155558. Instead of mediating the guard collective via a config option, this version enables it through a `set_stance`-like API. The motivation is that checking the config value on every entry into torch.compile is apparently quite expensive, according to functorch_maml_omniglot, so this approach makes entry a bit cheaper.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156562
Approved by: https://github.com/Microve
commit 17eb649d55 (parent 73772919d2), committed by PyTorch MergeBot
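Before the diff, here is a minimal usage sketch of the new API, modeled on the `enable_guard_collectives` helper and `test_guard_collective` test added below; it assumes torch.distributed is already initialized with one process per rank (guard collectives are SPMD-only).

import torch
import torch.distributed as dist

# Sketch only: assumes dist.init_process_group() has already run (SPMD, one process per rank).
old = torch.compiler.set_enable_guard_collectives(True)  # returns the previous setting
try:
    @torch.compile()
    def f(x):
        return x.sum()

    x = torch.randn(10, device=dist.get_rank())
    f(x)  # on entry to compiled code, ranks agree whether anyone needs to recompile
finally:
    torch.compiler.set_enable_guard_collectives(old)  # restore the previous setting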
@@ -21,6 +21,7 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
     list_backends
     disable
     set_stance
+    set_enable_guard_collectives
     cudagraph_mark_step_begin
     is_compiling
     is_dynamo_compiling
@@ -25,6 +25,7 @@ from torch._dynamo.comptime import comptime
 from torch._dynamo.testing import collect_results
 from torch._dynamo.utils import same
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
+from torch.compiler import set_enable_guard_collectives
 from torch.distributed._functional_collectives import _maybe_wrap_tensor
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import (
@@ -61,6 +62,15 @@ def init_weights(m):
     m.bias.data.fill_(0.01)
 
 
+@contextmanager
+def enable_guard_collectives():
+    old = set_enable_guard_collectives(True)
+    try:
+        yield
+    finally:
+        set_enable_guard_collectives(old)
+
+
 class ToyModel(nn.Module):
     def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
         super().__init__()
@@ -1141,6 +1151,31 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
            for r in res[1:]:
                self.assertEqual(res[0], r)
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @enable_guard_collectives()
+    def test_guard_collective(self):
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+            torch._dynamo.utils.clear_compilation_metrics()
+
+            @torch.compile()
+            def f(x):
+                return x.sum()
+
+            x = torch.randn(10, device=self.rank)
+            f(x)
+
+            if self.rank == 0:
+                x = torch.randn(10, device=self.rank)
+            else:
+                x = torch.randn(12, device=self.rank)  # recompile on one rank
+            f(x)
+
+            metrics = torch._dynamo.utils.get_compilation_metrics()
+            res = [None] * self.world_size
+            torch.distributed.all_gather_object(res, len(metrics))
+            for r in res[1:]:
+                self.assertEqual(res[0], r)
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_get_pg_attr(self):
         with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@@ -1,8 +1,13 @@
 import enum
 import types
-from typing import overload
+from typing import Optional, overload
 
-from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn
+from torch._dynamo.types import (
+    DynamoCallback,
+    DynamoGuardCompleteHook,
+    DynamoGuardHook,
+    GuardFn,
+)
 
 def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
 def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
@@ -13,6 +18,9 @@ def set_code_exec_strategy(
     code: types.CodeType, strategy: _FrameExecStrategy
 ) -> None: ...
 def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
+def set_guard_complete_hook(
+    hook: Optional[DynamoGuardCompleteHook],
+) -> Optional[DynamoGuardCompleteHook]: ...
 def raise_sigtrap() -> None: ...
 
 class _CacheEntry:
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
     from torch._C._dynamo.eval_frame import (  # noqa: F401
         reset_code,
         set_eval_frame,
+        set_guard_complete_hook,
         set_guard_error_hook,
         unsupported,
     )
@@ -22,6 +22,7 @@ from . import config
 
 
 _COMPILE_PG: Optional[dist.ProcessGroup] = None
+_GUARD_PG: Optional[dist.ProcessGroup] = None
 
 
 def get_compile_pg() -> Optional[dist.ProcessGroup]:
@@ -39,3 +40,15 @@ def get_compile_pg() -> Optional[dist.ProcessGroup]:
         return _COMPILE_PG
 
     return None
+
+
+# NB: Unlike get_compile_pg, this is only called when guard collectives were
+# explicitly requested
+def get_guard_pg() -> Optional[dist.ProcessGroup]:
+    if dist.is_available() and dist.is_initialized():
+        global _GUARD_PG
+        if _GUARD_PG is None:
+            _GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg")
+        return _GUARD_PG
+
+    return None
@@ -58,6 +58,7 @@ from torch._C._dynamo.eval_frame import (  # noqa: F401
     reset_code,
     set_code_exec_strategy,
     set_eval_frame,
+    set_guard_complete_hook,
     set_guard_error_hook,
     set_skip_guard_eval_unsafe,
     unsupported,
@@ -90,7 +91,7 @@ from torch.fx.experimental.symbolic_shapes import (
 )
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 
-from . import config, convert_frame, external_utils, trace_rules, utils
+from . import config, convert_frame, distributed, external_utils, trace_rules, utils
 from .backends.registry import CompilerFn, lookup_backend
 from .code_context import code_context
 from .exc import (
@@ -519,6 +520,38 @@ def _log_traced_frames():
     log.info(msg)
 
 
+def guard_collectives_hook(guard_eval_result):
+    import torch.distributed as dist
+    from torch._dynamo.utils import dynamo_timed
+
+    # guard_eval_result == True ==> cache hit
+    if pg := distributed.get_guard_pg():
+        with dynamo_timed(
+            "guard_collective", log_pt2_compile_event=True, log_waitcounter=True
+        ):
+            log.info("guard_collective %s", guard_eval_result)
+            torch._logging.trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "guard_collective",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: str(guard_eval_result),
+            )
+            # TODO: a bit awkward to time, this isn't inside of the dynamo compile region
+            all_results = [None] * pg.size()
+            dist.all_gather_object(all_results, guard_eval_result, group=pg)
+            # True = everyone hit, OK to run
+            # False = someone missed, force recompile everywhere
+            res = all(all_results)
+            log.info("guard_collective %s -> %s", guard_eval_result, res)
+            return res
+    return guard_eval_result
+
+
+_not_set = object()
+
+
 class _TorchDynamoContext:
     def __init__(
         self,
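To spell out the decision rule the hook implements: every rank contributes its local cache-hit result, and only a unanimous hit lets the cached code run; a single miss forces recompilation on every rank. Illustrative values only:

# Hypothetical per-rank guard results for a 4-rank job (illustration, not an API).
assert all([True, True, True, True])         # unanimous hit -> every rank runs the cached code
assert not all([True, False, True, True])    # one miss -> every rank treats it as a miss and recompiles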
@@ -114,6 +114,13 @@ class DynamoGuardHook(Protocol):
     ) -> None: ...
 
 
+class DynamoGuardCompleteHook(Protocol):
+    def __call__(
+        self,
+        cache_hit: bool,
+    ) -> bool: ...
+
+
 class ProfilerStartHook(Protocol):
     def __call__(
         self,
@@ -21,6 +21,7 @@ __all__ = [
     "list_backends",
     "disable",
     "set_stance",
+    "set_enable_guard_collectives",
     "cudagraph_mark_step_begin",
     "wrap_numpy",
     "is_compiling",
@@ -330,6 +331,35 @@ def set_stance(
 set_stance._dynamo_forbidden = True  # type: ignore[attr-defined]
 
 
+def set_enable_guard_collectives(enabled: bool):
+    """
+    Enables use of collectives *during* guard evaluation to synchronize behavior
+    across ranks. This is expensive: we have to issue a collective every time
+    we enter a compiled code region, even if no rank actually would need to
+    compile. This can help prevent NCCL hangs by ensuring that we never have a
+    situation where one rank starts recompiling while other ranks don't compile;
+    it is especially useful in conjunction with enable_compiler_collectives
+    where such a situation would immediately cause a hang (as it is necessary
+    for all ranks to compile at the same time to run compiler collectives). Like
+    compiler collectives, you can only run this on SPMD programs; you will hang
+    otherwise. Note that a guard collective is only issued if there is any
+    compiled code to guard on; if this is the first time we encounter a frame or
+    the frame is skipped, we don't issue collectives.
+
+    Returns the previous setting of enabled.
+    """
+    from torch._C._dynamo.eval_frame import set_guard_complete_hook  # noqa: F401
+    from torch._dynamo.eval_frame import guard_collectives_hook
+
+    if enabled:
+        return set_guard_complete_hook(guard_collectives_hook) is not None
+    else:
+        return set_guard_complete_hook(None) is not None
+
+
+set_enable_guard_collectives._dynamo_forbidden = True  # type: ignore[attr-defined]
+
+
 def cudagraph_mark_step_begin():
     """
     Indicates that a new iteration of inference or training is about to begin.
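The docstring above suggests pairing this with compiler collectives; a minimal sketch of that combination, assuming the pre-existing `enable_compiler_collectives` dynamo config flag and an already-initialized SPMD process group:

import torch
import torch._dynamo.config as dynamo_config

# Sketch only: both features require an SPMD program with torch.distributed initialized,
# otherwise ranks will hang waiting on each other.
torch.compiler.set_enable_guard_collectives(True)
with dynamo_config.patch(enable_compiler_collectives=True):
    @torch.compile()
    def step(x):
        return (x * 2).sum()

    step(torch.randn(8))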
@@ -11,6 +11,7 @@
 #include <torch/csrc/utils/python_compat.h>
 
 PyObject* guard_error_hook = NULL;
+PyObject* guard_complete_hook = NULL;
 
 typedef struct {
   int active_dynamo_threads;
@@ -626,6 +627,22 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
   Py_RETURN_NONE;
 }
 
+static PyObject* set_guard_complete_hook(PyObject* dummy, PyObject* obj) {
+  PyObject* old_hook = guard_complete_hook;
+
+  if (obj == Py_None) {
+    obj = NULL;
+  }
+
+  guard_complete_hook = Py_XNewRef(obj);
+
+  if (old_hook == NULL) {
+    Py_RETURN_NONE;
+  } else {
+    return old_hook;
+  }
+}
+
 // Debugging function for GNU C only.
 // Used to set gdb breakpoints in hot CPython sites from Python.
 // Code example:
@@ -666,6 +683,7 @@ static PyMethodDef _methods[] = {
     {"unsupported", unsupported, METH_VARARGS, NULL},
     {"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL},
    {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
+    {"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL},
    {"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL},
    {NULL, NULL, 0, NULL}};
 
@@ -7,6 +7,10 @@
 #include <torch/csrc/dynamo/framelocals_mapping.h>
 #include <torch/csrc/utils/python_compat.h>
 
+extern "C" {
+extern PyObject* guard_complete_hook;
+}
+
 static constexpr const char* cache_lookup_profiler_str =
     "TorchDynamo Cache Lookup";
 
@@ -197,7 +201,23 @@ PyObject* dynamo__custom_eval_frame(
     // guard eval failed, keep propagating
     fail();
     return eval_result;
-  } else if (maybe_cached_code != Py_None) {
+  }
+
+  // NB: We only do guard collectives when there are any compiled code entries
+  // at all; this reduces overtriggering and we don't need to do guard
+  // collectives the very first time we've seen a frame
+  // TODO: We could also check if we had just created extra for the first
+  // time? Not too sure the best condition for extra->cache_entry_list
+  if (guard_complete_hook != nullptr && !extra->cache_entry_list.empty()) {
+    py::handle guard_complete_hook_handle(guard_complete_hook);
+    // False means force compilation (someone cache missed)
+    py::object res = guard_complete_hook_handle(maybe_cached_code != Py_None);
+    if (!py::cast<bool>(res)) {
+      maybe_cached_code = Py_None; // NB: non-owning
+    }
+  }
+
+  if (maybe_cached_code != Py_None) {
     cached_code = (PyCodeObject*)maybe_cached_code;
     // used cached version
     DEBUG_TRACE("cache hit %s", get_frame_name(frame));
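For readers who do not follow the C++ side, a rough Python paraphrase of the new eval-frame check above (names are illustrative; the real logic lives in `dynamo__custom_eval_frame`):

# Rough paraphrase of the C++ hunk above, for intuition only.
def lookup_with_guard_collective(guard_complete_hook, cache_entry_list, cached_code):
    # The hook is only consulted once the frame has at least one compiled entry,
    # so the very first encounter with a frame never issues a collective.
    if guard_complete_hook is not None and cache_entry_list:
        cache_hit = cached_code is not None
        if not guard_complete_hook(cache_hit):
            cached_code = None  # some rank missed: force every rank to recompile
    return cached_code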