Compare commits

...

11 Commits

SHA1         Message                                    Date
7a5382f662   Update [ghstack-poisoned]                  2025-10-30 07:46:14 -07:00
d9a5e6771b   Update [ghstack-poisoned]                  2025-10-30 07:00:20 -07:00
9f8f34f235   Update [ghstack-poisoned]                  2025-10-29 15:08:42 -07:00
a6cc90a739   Update [ghstack-poisoned]                  2025-10-29 14:39:27 -07:00
e992700f8b   Update [ghstack-poisoned]                  2025-10-29 12:01:37 -07:00
109d674520   Update [ghstack-poisoned]                  2025-10-29 11:22:12 -07:00
68fd4dedd4   Update [ghstack-poisoned]                  2025-10-29 10:23:09 -07:00
70ef2f215e   Update [ghstack-poisoned]                  2025-10-29 09:56:12 -07:00
31a51046a3   Update [ghstack-poisoned]                  2025-10-29 09:53:37 -07:00
c2fbe8f2dd   Update [ghstack-poisoned]                  2025-10-29 09:21:44 -07:00
b78bca6ecd   Update (base update) [ghstack-poisoned]    2025-10-29 09:21:44 -07:00
6 changed files with 142 additions and 6 deletions

View File

@@ -1,6 +1,8 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <optional>
namespace c10 {
@@ -19,7 +21,8 @@ struct C10_API AutogradState {
inference_mode_(inference_mode),
fw_grad_mode_(fw_grad_mode),
multithreading_enabled_(multithreading_enabled),
view_replay_enabled_(false) {}
view_replay_enabled_(false),
graph_exec_group_(std::nullopt) {}
void set_grad_mode(bool enabled) {
grad_mode_ = enabled;
@@ -41,6 +44,10 @@ struct C10_API AutogradState {
view_replay_enabled_ = view_replay_enabled;
}
void set_graph_exec_group(std::optional<SafePyObject> group) {
graph_exec_group_ = std::move(group);
}
bool get_grad_mode() const {
return grad_mode_;
}
@@ -61,6 +68,10 @@ struct C10_API AutogradState {
return view_replay_enabled_;
}
const std::optional<SafePyObject>& get_graph_exec_group() const {
return graph_exec_group_;
}
private:
bool grad_mode_ : 1;
bool inference_mode_ : 1;
@@ -68,6 +79,7 @@ struct C10_API AutogradState {
bool multithreading_enabled_ : 1;
// NOLINTNEXTLINE(cppcoreguidelines-use-default-member-init)
bool view_replay_enabled_ : 1;
std::optional<SafePyObject> graph_exec_group_;
};
} // namespace c10
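
The new graph_exec_group_ slot lives on AutogradState's thread-local state, so a freshly spawned Python thread should not observe a group entered on another thread. A minimal sketch of that expectation, assuming this stack is applied (it uses the private torch._C._get_graph_exec_group binding added further down in this diff):

import threading

import torch
from torch.utils.checkpoint import GraphExecGroup

seen = {}

def worker():
    # A new thread gets fresh autograd TLS, so no group should be visible here.
    seen["group"] = torch._C._get_graph_exec_group()

with GraphExecGroup():
    t = threading.Thread(target=worker)
    t.start()
    t.join()

assert seen["group"] is None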

View File

@@ -5223,6 +5223,7 @@ xfail_by_backend = {
"test_reentrant_with_callbacks_both_depths", # queue_callback
"test_reentrant_with_callbacks_depth_0", # queue_callback
"test_reentrant_with_callbacks_depth_1", # queue_callback
"test_checkpoint_graph_execution_group", # Attempted to call function marked as skipped
"test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_post_accumulate_grad_hook_ordering", # accuracy error

View File

@@ -7362,6 +7362,60 @@ for shape in [(1,), ()]:
):
checkpoint_sequential(modules_list, 3, a)
    @skipIfTorchDynamo("GraphExecGroup does not support compile")
    def test_checkpoint_graph_execution_group(self):
        def run(use_graph_execution_group):
            counter = [0]

            def fn(x):
                counter[0] += 1
                y = x.sin().cos()
                z = y.sin().cos()
                return y, z

            x = torch.randn(3, 3, requires_grad=True)
            y, z = checkpoint(fn, x, use_reentrant=False)

            group = torch.utils.checkpoint.GraphExecGroup()
            ctx = contextlib.nullcontext()
            if use_graph_execution_group:
                ctx = group

            with ctx:
                (grad_y,) = torch.autograd.grad(
                    z, inputs=(y,), grad_outputs=(torch.ones(3, 3),)
                )
                (grad_x,) = torch.autograd.grad(
                    y,
                    inputs=(x,),
                    grad_outputs=(grad_y,),
                )

            if use_graph_execution_group:
                self.assertEqual(counter[0], 2)
            else:
                self.assertEqual(counter[0], 3)

        run(use_graph_execution_group=True)
        run(use_graph_execution_group=False)

        # Test the not actually disjoint case (using retain_graph=True since
        # otherwise autograd itself will catch this)
        def fn(x):
            return x.sin().cos()

        x = torch.randn(3, 3, requires_grad=True)
        out = checkpoint(fn, x, use_reentrant=False)

        with torch.utils.checkpoint.GraphExecGroup():
            # Under this context, we will enforce that two backward are disjoint
            # even if retain_graph=True.
            out.sum().backward(retain_graph=True)

            with self.assertRaisesRegex(RuntimeError, "was already unpacked once"):
                out.sum().backward()
def test_checkpoint_detects_non_determinism(self):
def save_3_tensors(x):
out = x.sin().exp()

View File

@@ -67,6 +67,7 @@ from torch.types import (
Storage,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import GraphExecGroup
# This module is defined in torch/csrc/Module.cpp
@@ -1488,6 +1489,8 @@ def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
def _set_graph_exec_group(group: GraphExecGroup | None) -> None: ...
def _get_graph_exec_group() -> GraphExecGroup | None: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...

View File

@@ -1218,6 +1218,33 @@ static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) {
END_HANDLE_TH_ERRORS
}
static PyObject* set_graph_exec_group(PyObject* self, PyObject* obj) {
HANDLE_TH_ERRORS
if (obj == Py_None) {
c10::AutogradState::get_tls_state().set_graph_exec_group(std::nullopt);
} else {
Py_INCREF(obj);
c10::AutogradState::get_tls_state().set_graph_exec_group(
c10::SafePyObject(obj, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* get_graph_exec_group(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
const auto& group =
c10::AutogradState::get_tls_state().get_graph_exec_group();
if (group.has_value()) {
PyObject* obj = group->ptr(getPyInterpreter());
Py_INCREF(obj);
return obj;
} else {
Py_RETURN_NONE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (c10::InferenceMode::is_enabled()) {
@@ -1598,6 +1625,8 @@ static PyMethodDef methods[] = {
castPyCFunctionWithKeywords(set_view_replay_enabled),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_set_graph_exec_group", set_graph_exec_group, METH_O, nullptr},
{"_get_graph_exec_group", get_graph_exec_group, METH_NOARGS, nullptr},
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
{"_exit_dual_level",
castPyCFunctionWithKeywords(python_exit_dual_level),
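
Because set_graph_exec_group stores the Python object behind a c10::SafePyObject (taking an extra reference) and get_graph_exec_group hands back that same pointer, the expected behavior of the new private bindings is a simple identity round trip. A small sanity sketch, assuming this stack is applied:

import torch
from torch.utils.checkpoint import GraphExecGroup

group = GraphExecGroup()

with group:
    # __enter__ stored `self` in the autograd TLS, so the getter should
    # return the very same Python object, not a copy.
    assert torch._C._get_graph_exec_group() is group

# __exit__ cleared the slot (std::nullopt on the C++ side), so None comes back.
assert torch._C._get_graph_exec_group() is None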

View File

@@ -32,6 +32,7 @@ __all__ = [
"SelectiveCheckpointContext",
"create_selective_checkpoint_contexts",
"SAC_IGNORED_OPS",
"GraphExecGroup",
]
_DEFAULT_DETERMINISM_MODE = "default"
@@ -1069,7 +1070,7 @@ class _StopRecomputationError(Exception):
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, target_frame_ref: ReferenceType, gid: int):
    def __init__(self, target_frame_ref: ReferenceType, gid: Union["GraphExecGroup", int]):
        def pack_hook(x):
            x = x.detach() if x.requires_grad else x
            target_frame = target_frame_ref()
@@ -1140,10 +1141,14 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
            return holder

        def unpack_hook(holder):
            gid = torch._C._current_graph_task_id()
            if gid == -1:
                # generate a temporary id if we trigger unpack outside of a backward call
                gid = int(uuid.uuid4())
            # First check if we're inside a GraphExecGroup context
            gid: Union[GraphExecGroup, None, int] = GraphExecGroup._get_current_group()
            if gid is None:
                # Fallback to using the current graph task id
                gid = torch._C._current_graph_task_id()
                if gid == -1:
                    # generate a temporary id if we trigger unpack outside of a backward call
                    gid = int(uuid.uuid4())

            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
@@ -1587,6 +1592,38 @@ def _checkpoint_without_reentrant_generator(
return
class GraphExecGroup:
    """Any checkpointed regions encountered by backward under the same instance
    of this context manager will trigger recompute at most once, even if
    there are multiple calls to backward.

    Backward calls under the same instance of this context manager must execute
    over non-overlapping regions of the backward graph even if retain_graph=True.

    .. note::
        This context manager only affects checkpoint with use_reentrant=False, and
        is a no-op otherwise.
    """

    def __enter__(self) -> "GraphExecGroup":
        if torch._C._get_graph_exec_group() is not None:
            raise RuntimeError(
                "GraphExecGroup contexts cannot be nested. "
                f"Already inside group {torch._C._get_graph_exec_group()}"
            )
        torch._C._set_graph_exec_group(self)
        return self

    def __exit__(self, *args: object) -> None:
        torch._C._set_graph_exec_group(None)

    @classmethod
    def _get_current_group(cls) -> Optional["GraphExecGroup"]:
        # Private API to be used by utils like AC
        return torch._C._get_graph_exec_group()
# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.
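
To tie the pieces together, here is a minimal usage sketch distilled from the new test_checkpoint_graph_execution_group test: two torch.autograd.grad calls over disjoint halves of a checkpointed graph share a single recompute when run under one GraphExecGroup (assuming this stack is applied):

import torch
from torch.utils.checkpoint import GraphExecGroup, checkpoint

counter = [0]

def fn(x):
    counter[0] += 1  # counts the forward run plus any recomputes of the region
    y = x.sin().cos()
    z = y.sin().cos()
    return y, z

x = torch.randn(3, 3, requires_grad=True)
y, z = checkpoint(fn, x, use_reentrant=False)

with GraphExecGroup():
    # Two backward calls over non-overlapping regions of the graph: z -> y,
    # then y -> x. Under the group they share a single recompute of `fn`.
    (grad_y,) = torch.autograd.grad(z, inputs=(y,), grad_outputs=(torch.ones(3, 3),))
    (grad_x,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=(grad_y,))

assert counter[0] == 2  # one forward plus one recompute; it would be 3 without the group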