Compare commits

...

14 Commits

SHA1        Message                                   Date
1b9b8f52ae  Update [ghstack-poisoned]                 2025-10-31 13:25:09 -07:00
0d559d0c20  Update [ghstack-poisoned]                 2025-10-31 12:12:11 -07:00
70714103b1  Update [ghstack-poisoned]                 2025-10-31 07:01:00 -07:00
5d7e730359  Update [ghstack-poisoned]                 2025-10-30 12:32:35 -07:00
363b1d2b49  Update [ghstack-poisoned]                 2025-10-30 09:33:33 -07:00
f278c43737  Update [ghstack-poisoned]                 2025-10-30 09:13:32 -07:00
2870894809  Update [ghstack-poisoned]                 2025-10-30 09:03:06 -07:00
cb0bb1d8bb  Update [ghstack-poisoned]                 2025-10-30 08:53:03 -07:00
a1ee245e3e  Update [ghstack-poisoned]                 2025-10-30 07:46:15 -07:00
66990f8dea  Update (base update) [ghstack-poisoned]   2025-10-30 07:46:15 -07:00
6913ecb72e  Update [ghstack-poisoned]                 2025-10-30 07:00:22 -07:00
9266afcde2  Update (base update) [ghstack-poisoned]   2025-10-30 07:00:22 -07:00
4708491c8d  Update [ghstack-poisoned]                 2025-10-29 20:34:44 -07:00
a3d40e72f2  Update (base update) [ghstack-poisoned]   2025-10-29 20:34:44 -07:00
11 changed files with 731 additions and 448 deletions

View File

@@ -1,6 +1,8 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <optional>
namespace c10 {
@@ -19,7 +21,8 @@ struct C10_API AutogradState {
inference_mode_(inference_mode),
fw_grad_mode_(fw_grad_mode),
multithreading_enabled_(multithreading_enabled),
view_replay_enabled_(false) {}
view_replay_enabled_(false),
graph_exec_group_(std::nullopt) {}
void set_grad_mode(bool enabled) {
grad_mode_ = enabled;
@@ -41,6 +44,10 @@ struct C10_API AutogradState {
view_replay_enabled_ = view_replay_enabled;
}
void set_graph_exec_group(std::optional<SafePyObject> group) {
graph_exec_group_ = std::move(group);
}
bool get_grad_mode() const {
return grad_mode_;
}
@@ -61,6 +68,10 @@ struct C10_API AutogradState {
return view_replay_enabled_;
}
const std::optional<SafePyObject>& get_graph_exec_group() const {
return graph_exec_group_;
}
private:
bool grad_mode_ : 1;
bool inference_mode_ : 1;
@@ -68,6 +79,7 @@ struct C10_API AutogradState {
bool multithreading_enabled_ : 1;
// NOLINTNEXTLINE(cppcoreguidelines-use-default-member-init)
bool view_replay_enabled_ : 1;
std::optional<SafePyObject> graph_exec_group_;
};
} // namespace c10

File diff suppressed because it is too large.

View File

@@ -2625,7 +2625,7 @@ def forward(self, primals_1, primals_2):
return grad_output * x, grad_output * x
def f(a, b):
return FwBwMutation.apply(a, b)
return FwBwMutation.apply(a, b).sin_().clone()
inps = [
torch.ones(3, 3, requires_grad=True),
@@ -2674,17 +2674,22 @@ def forward(self, primals_1, primals_2):
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
return (mul, add)""",
clone = torch.ops.aten.clone.default(mul)
sin_ = torch.ops.aten.sin_.default(mul); mul = None
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
return (clone_1, add, clone)""",
)
# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, tangents_1):
def forward(self, add, clone, tangents_1):
cos = torch.ops.aten.cos.default(clone); clone = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
return (mul_2, None)""",
)
def test_fw_bw_mutation_no_functionalization2(self):
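For context (the class itself sits outside the shown hunks), the FwBwMutation function under test is presumably an autograd.Function along the lines of the hypothetical reconstruction below, inferred from the expected graphs: it mutates a tensor in-place in the forward and again in the backward, which is exactly what the partitioner changes further down must keep on the correct side of the forward/backward split. The real test code may differ in details.

    import torch

    class FwBwMutationSketch(torch.autograd.Function):  # hypothetical reconstruction
        @staticmethod
        def forward(ctx, a, b):
            add = b + 1
            torch._foreach_mul_([add], 3)   # in-place mutation that must stay in the forward
            ctx.save_for_backward(add)
            return add * a

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            torch._foreach_mul_([x], 4)     # in-place mutation that must stay in the backward
            return grad_output * x, grad_output * x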

View File

@@ -5223,6 +5223,7 @@ xfail_by_backend = {
"test_reentrant_with_callbacks_both_depths", # queue_callback
"test_reentrant_with_callbacks_depth_0", # queue_callback
"test_reentrant_with_callbacks_depth_1", # queue_callback
"test_checkpoint_graph_execution_group", # Attempted to call function marked as skipped
"test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_post_accumulate_grad_hook_ordering", # accuracy error

View File

@@ -7362,6 +7362,60 @@ for shape in [(1,), ()]:
):
checkpoint_sequential(modules_list, 3, a)
@skipIfTorchDynamo("GraphExecGroup does not support compile")
def test_checkpoint_graph_execution_group(self):
def run(use_graph_execution_group):
counter = [0]
def fn(x):
counter[0] += 1
y = x.sin().cos()
z = y.sin().cos()
return y, z
x = torch.randn(3, 3, requires_grad=True)
y, z = checkpoint(fn, x, use_reentrant=False)
group = torch.utils.checkpoint.GraphExecGroup()
ctx = contextlib.nullcontext()
if use_graph_execution_group:
ctx = group
with ctx:
(grad_y,) = torch.autograd.grad(
z, inputs=(y,), grad_outputs=(torch.ones(3, 3),)
)
(grad_x,) = torch.autograd.grad(
y,
inputs=(x,),
grad_outputs=(grad_y,),
)
if use_graph_execution_group:
self.assertEqual(counter[0], 2)
else:
self.assertEqual(counter[0], 3)
run(use_graph_execution_group=True)
run(use_graph_execution_group=False)
# Test the not actually disjoint case (using retain_graph=True since
# otherwise autograd itself will catch this)
def fn(x):
return x.sin().cos()
x = torch.randn(3, 3, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False)
with torch.utils.checkpoint.GraphExecGroup():
# Under this context, we will enforce that two backward are disjoint
# even if retain_graph=True.
out.sum().backward(retain_graph=True)
with self.assertRaisesRegex(RuntimeError, "was already unpacked once"):
out.sum().backward()
def test_checkpoint_detects_non_determinism(self):
def save_3_tensors(x):
out = x.sin().exp()

View File

@@ -67,6 +67,7 @@ from torch.types import (
Storage,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import GraphExecGroup
# This module is defined in torch/csrc/Module.cpp
@@ -1488,6 +1489,8 @@ def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
def _set_graph_exec_group(group: GraphExecGroup | None) -> None: ...
def _get_graph_exec_group() -> GraphExecGroup | None: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
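A small sanity-check sketch of the two new private bindings (normally reached only through the GraphExecGroup context manager added in torch/utils/checkpoint.py further down): the group object round-trips through autograd TLS unchanged.

    import torch
    from torch.utils.checkpoint import GraphExecGroup

    group = GraphExecGroup()
    assert torch._C._get_graph_exec_group() is None   # nothing set by default
    torch._C._set_graph_exec_group(group)             # stored as a SafePyObject in AutogradState TLS
    assert torch._C._get_graph_exec_group() is group  # the same Python object comes back
    torch._C._set_graph_exec_group(None)              # clears the slot (std::nullopt on the C++ side)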

View File

@@ -289,11 +289,15 @@ def create_joint(
]:
outs_descs = None
if primals_descs is None:
outs, tangent_mask = fn(*primals)
with set_partitioner_tag_is_forward():
outs, tangent_mask = fn(*primals)
assert not pytree.tree_any(lambda x: isinstance(x, AOTOutput), tangent_mask)
else:
def fn_wrapped(*args, **kwargs):
with set_partitioner_tag_is_forward():
return fn(*args, **kwargs)
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
fn, primals
fn_wrapped, primals
)
# TODO: I think this hook can also be eliminated now
@@ -559,6 +563,10 @@ def set_partitioner_tag_is_backward():
return set_partitioner_tag("is_backward")
def set_partitioner_tag_is_forward():
return set_partitioner_tag("is_forward")
def set_partitioner_tag_must_be_in_backward():
return set_partitioner_tag("must_be_in_backward")
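For readers unfamiliar with partitioner tags: during joint tracing, the set_partitioner_tag_* context managers stamp a "partitioner_tag" entry into each created node's meta (via fx.traceback; see the set_current_meta change below), and predicates such as _has_tag_is_forward then query that entry. The following standalone toy snippet only illustrates that node.meta query pattern, not the real tracing path.

    import torch.fx as fx

    def f(x):
        return x.sin().cos()

    gm = fx.symbolic_trace(f)
    # Pretend the tracer ran under set_partitioner_tag_is_forward():
    for node in gm.graph.nodes:
        if node.op == "call_method":
            node.meta["partitioner_tag"] = "is_forward"

    forward_tagged = [n.name for n in gm.graph.nodes
                      if n.meta.get("partitioner_tag", None) == "is_forward"]
    print(forward_tagged)  # ['sin', 'cos']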

View File

@@ -51,6 +51,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import assert_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@@ -287,6 +288,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_is_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_forward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
@@ -1011,105 +1016,89 @@ def default_partition(
Returns:
Returns the generated forward and backward Fx graph modules.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
# In the min-cut partitioner, we run CSE, but we don't have non-pure aware CSE
# Maybe we could conditionally run if the graph is functional.
# fx_g = joint_module.graph
# if config.cse:
# assert_functional_graph(joint_module.graph)
# cse_graph = fx_graph_cse(fx_g)
# joint_module.graph = cse_graph
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
# Respect the original placement of ops rather than rely on dataflow.
forward_nodes = []
last_node = None
for node in joint_module.graph.nodes:
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
last_node = node
assert last_node is not None
for node in joint_module.graph.nodes:
if not _is_tangent(node):
forward_nodes.append(node)
if node is last_node:
break
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
node.name for node in forward_nodes if node.op != "output"
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
saved_values = []
saved_sym_nodes = []
def is_mutated_later_in_fw(node):
if _has_tag_is_backward(node):
return False
tensor_arg_aliases = [
x
for x in node.args
if isinstance(x, fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor)
]
while len(tensor_arg_aliases) > 0:
a = tensor_arg_aliases.pop()
for u in a.users:
if not isinstance(u.target, torch._ops.OpOverload):
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
if (
# one of the args was mutated
u.target._schema.is_mutable
# and the mutation happens "later"
and order[u] > order[node]
# and the mutation happened during the forward
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
):
for idx, alias_info in enumerate(u.target._schema.arguments):
if alias_info.is_write and u.args[idx] is a:
return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
assert_functional_graph(joint_module.graph)
joint_module = cleanup_recompute_tags(joint_module)
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
elif (
continue
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
saved_values.append(node)
continue
if node.is_impure(impure_random=False) and node.op not in ("placeholder", "output"):
# See is_impure in torch/fx/node.py
assert not graph_has_recomputable_ops, (
"Trying to apply AC on a graph with impure op", node, node.target
)
saved_values.append(node)
continue
backward_usages = [
n for n in node.users if n.name not in forward_node_names
]
if (
"tensor_meta" in node.meta
and all(is_sym_node(n) for n in backward_usages)
):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
continue
if (
"tensor_meta" not in node.meta
and node.op == "call_function"
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
):
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target == operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [
n for n in node.users if n.name not in forward_node_names
]
if "tensor_meta" in node.meta and all(
is_sym_node(n) for n in backward_usages
):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
else:
saved_values.append(node)
assert all(user.target == operator.getitem for user in node.users)
continue
if not must_recompute(node):
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
return _extract_fwd_bwd_modules(
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
@@ -1117,6 +1106,25 @@ def default_partition(
static_lifetime_input_nodes=static_lifetime_input_nodes,
)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)
# raise all getitem ops to as early as possible
# this is helpful for memory, especially in the case of aot_eager backend
fw_module = raise_getitems(fw_module)
bw_module = raise_getitems(bw_module)
# Failing for python test/functorch/test_aotdispatch.py -k test_some_output_requires_grad_input_doesnt
# last_input = next(reversed(module.graph.find_nodes(op="placeholder")) # StopIteration
# fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
# bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
return fw_module, bw_module
INT_INF = int(1e6)
@@ -1650,6 +1658,9 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
# in forward graph outputs. With this, we can break the above circular dependency.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
else:
if any(must_recompute(user) for user in node.users):
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
return joint_module
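Because the default_partition hunk above interleaves removed and added lines, the new saving policy is easier to follow as a condensed paraphrase. Roughly (helper predicates are passed in for brevity, and the tuple-of-tensors/getitem case plus the final dedup are omitted), a joint-graph node is handled like this:

    from torch.utils.checkpoint import CheckpointPolicy

    def classify_node(node, forward_node_names, is_mutated_later_in_fw,
                      is_sym_node, must_recompute):
        """Condensed sketch of the new default_partition saving policy."""
        if node.name in forward_node_names:
            return "forward"    # kept in the forward; placement comes from the tags, not dataflow
        if is_mutated_later_in_fw(node):
            return "save"       # one of its args is mutated later in the forward
        if is_sym_node(node):
            return "save_sym"   # symints are stashed separately from saved tensors
        if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
            return "save"
        if node.is_impure(impure_random=False) and node.op not in ("placeholder", "output"):
            return "save"       # impure ops must not be recomputed
        backward_usages = [n for n in node.users if n.name not in forward_node_names]
        if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
            return "save_sym"   # only sizes/strides are needed in the backward
        if not must_recompute(node):
            return "save"
        return "recompute"      # checkpointed: recomputed in the backward graph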

View File

@@ -1218,6 +1218,33 @@ static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) {
END_HANDLE_TH_ERRORS
}
static PyObject* set_graph_exec_group(PyObject* self, PyObject* obj) {
HANDLE_TH_ERRORS
if (obj == Py_None) {
c10::AutogradState::get_tls_state().set_graph_exec_group(std::nullopt);
} else {
Py_INCREF(obj);
c10::AutogradState::get_tls_state().set_graph_exec_group(
c10::SafePyObject(obj, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* get_graph_exec_group(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
const auto& group =
c10::AutogradState::get_tls_state().get_graph_exec_group();
if (group.has_value()) {
PyObject* obj = group->ptr(getPyInterpreter());
Py_INCREF(obj);
return obj;
} else {
Py_RETURN_NONE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (c10::InferenceMode::is_enabled()) {
@@ -1598,6 +1625,8 @@ static PyMethodDef methods[] = {
castPyCFunctionWithKeywords(set_view_replay_enabled),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_set_graph_exec_group", set_graph_exec_group, METH_O, nullptr},
{"_get_graph_exec_group", get_graph_exec_group, METH_NOARGS, nullptr},
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
{"_exit_dual_level",
castPyCFunctionWithKeywords(python_exit_dual_level),

View File

@@ -349,6 +349,10 @@ def set_current_meta(node, pass_name=""):
current_meta["from_node"] = [
NodeSource(node, pass_name, NodeSourceAction.CREATE)
]
# We need something like this because the AC HOP runs with a fx.Interpreter
# overriding the partitioner_tag context in the outer scope.
if saved_meta.get("partitioner_tag") is not None:
current_meta["partitioner_tag"] = saved_meta["partitioner_tag"]
yield
finally:
current_meta = saved_meta

View File

@@ -32,6 +32,7 @@ __all__ = [
"SelectiveCheckpointContext",
"create_selective_checkpoint_contexts",
"SAC_IGNORED_OPS",
"GraphExecGroup",
]
_DEFAULT_DETERMINISM_MODE = "default"
@@ -1069,7 +1070,7 @@ class _StopRecomputationError(Exception):
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, target_frame_ref: ReferenceType, gid: int):
def __init__(self, target_frame_ref: ReferenceType, gid: Union["GraphExecGroup", int]):
def pack_hook(x):
x = x.detach() if x.requires_grad else x
target_frame = target_frame_ref()
@@ -1140,10 +1141,14 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
return holder
def unpack_hook(holder):
gid = torch._C._current_graph_task_id()
if gid == -1:
# generate a temporary id if we trigger unpack outside of a backward call
gid = int(uuid.uuid4())
# First check if we're inside a GraphExecGroup context
gid: Union[GraphExecGroup, None, int] = GraphExecGroup._get_current_group()
if gid is None:
# Fallback to using the current graph task id
gid = torch._C._current_graph_task_id()
if gid == -1:
# generate a temporary id if we trigger unpack outside of a backward call
gid = int(uuid.uuid4())
if not frame.is_recomputed[gid]:
ctx = frame.input_saver.grad_fn
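The effect of the new unpack-hook branch can be stated as one key-selection rule, using only functions that appear in this diff: the recomputation cache is keyed by the active GraphExecGroup when one is set, by the current backward graph task otherwise, and by a throwaway id when unpacking happens outside any backward call.

    import uuid
    import torch
    from torch.utils.checkpoint import GraphExecGroup

    def recompute_key():
        group = GraphExecGroup._get_current_group()
        if group is not None:
            return group                  # shared across backward calls in the group
        gid = torch._C._current_graph_task_id()
        if gid != -1:
            return gid                    # normal case: one backward call, one recompute
        return int(uuid.uuid4())          # outside backward: always recompute fresh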
@@ -1587,6 +1592,38 @@ def _checkpoint_without_reentrant_generator(
return
class GraphExecGroup:
"""Any checkpointed regions encountered by backward under the same instance
of this context manager will trigger recompute at most once, even if
there are multiple calls to backward.
Backward calls under the same instance of this context manager must execute
over non-overlapping regions of the backward graph even if retain_graph=True.
.. note::
This context manager only affects checkpoint with use_reentrant=False, and
is a no-op otherwise.
"""
def __enter__(self) -> "GraphExecGroup":
if torch._C._get_graph_exec_group() is not None:
raise RuntimeError(
"GraphExecGroup contexts cannot be nested. "
f"Already inside group {torch._C._get_graph_exec_group()}"
)
torch._C._set_graph_exec_group(self)
return self
def __exit__(self, *args: object) -> None:
torch._C._set_graph_exec_group(None)
@classmethod
def _get_current_group(cls) -> Optional["GraphExecGroup"]:
# Private API to be used by utils like AC
return torch._C._get_graph_exec_group()
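A minimal usage sketch mirroring the new test above: the two autograd.grad calls walk disjoint halves of the same checkpointed graph, so under a single GraphExecGroup the checkpointed region is recomputed only once rather than once per backward call.

    import torch
    from torch.utils.checkpoint import checkpoint, GraphExecGroup

    def fn(t):
        y = t.sin().cos()
        return y, y.sin().cos()

    x = torch.randn(3, 3, requires_grad=True)
    y, z = checkpoint(fn, x, use_reentrant=False)

    with GraphExecGroup():
        # backward over z -> y, then over y -> x: non-overlapping regions
        (grad_y,) = torch.autograd.grad(z, inputs=(y,), grad_outputs=(torch.ones(3, 3),))
        (grad_x,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=(grad_y,))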
# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.