Compare commits

...

31 Commits

SHA1        Message                                  Date
19fc2d35d7  Update [ghstack-poisoned]                2025-11-13 09:25:41 -08:00
8a02e7fd14  Update [ghstack-poisoned]                2025-11-12 16:01:57 -08:00
eea6b6f0e3  Update [ghstack-poisoned]                2025-11-07 11:56:50 -08:00
b2f6928f34  Update (base update) [ghstack-poisoned]  2025-11-07 11:56:50 -08:00
09bb739aa3  Update [ghstack-poisoned]                2025-11-07 10:14:36 -08:00
983b20821d  Update (base update) [ghstack-poisoned]  2025-11-07 10:14:36 -08:00
dede07ca26  Update [ghstack-poisoned]                2025-11-07 09:07:32 -08:00
eecc531f81  Update [ghstack-poisoned]                2025-11-06 14:41:05 -08:00
ca7270992c  Update (base update) [ghstack-poisoned]  2025-11-06 12:31:11 -08:00
eb5947d25f  Update [ghstack-poisoned]                2025-11-06 12:31:11 -08:00
1b94305bd0  Update [ghstack-poisoned]                2025-11-04 15:13:55 -08:00
c1e0a9318d  Update [ghstack-poisoned]                2025-11-04 14:57:06 -08:00
5c95ba1705  Update [ghstack-poisoned]                2025-11-04 14:17:16 -08:00
fdbf5935c2  Update (base update) [ghstack-poisoned]  2025-11-03 13:10:51 -08:00
be9493cbd8  Update [ghstack-poisoned]                2025-11-03 13:10:51 -08:00
10b2b1a8bc  Update [ghstack-poisoned]                2025-11-03 12:53:29 -08:00
46eec57fbc  Update [ghstack-poisoned]                2025-11-03 12:50:49 -08:00
1b9b8f52ae  Update [ghstack-poisoned]                2025-10-31 13:25:09 -07:00
0d559d0c20  Update [ghstack-poisoned]                2025-10-31 12:12:11 -07:00
70714103b1  Update [ghstack-poisoned]                2025-10-31 07:01:00 -07:00
5d7e730359  Update [ghstack-poisoned]                2025-10-30 12:32:35 -07:00
363b1d2b49  Update [ghstack-poisoned]                2025-10-30 09:33:33 -07:00
f278c43737  Update [ghstack-poisoned]                2025-10-30 09:13:32 -07:00
2870894809  Update [ghstack-poisoned]                2025-10-30 09:03:06 -07:00
cb0bb1d8bb  Update [ghstack-poisoned]                2025-10-30 08:53:03 -07:00
66990f8dea  Update (base update) [ghstack-poisoned]  2025-10-30 07:46:15 -07:00
a1ee245e3e  Update [ghstack-poisoned]                2025-10-30 07:46:15 -07:00
9266afcde2  Update (base update) [ghstack-poisoned]  2025-10-30 07:00:22 -07:00
6913ecb72e  Update [ghstack-poisoned]                2025-10-30 07:00:22 -07:00
a3d40e72f2  Update (base update) [ghstack-poisoned]  2025-10-29 20:34:44 -07:00
4708491c8d  Update [ghstack-poisoned]                2025-10-29 20:34:44 -07:00
13 changed files with 619 additions and 212 deletions

View File

@ -1,6 +1,8 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <optional>
namespace c10 {
@ -15,7 +17,8 @@ struct C10_API AutogradState {
bool inference_mode,
bool fw_grad_mode,
bool multithreading_enabled)
: grad_mode_(grad_mode),
: graph_exec_group_(std::nullopt),
grad_mode_(grad_mode),
inference_mode_(inference_mode),
fw_grad_mode_(fw_grad_mode),
multithreading_enabled_(multithreading_enabled),
@ -41,6 +44,10 @@ struct C10_API AutogradState {
view_replay_enabled_ = view_replay_enabled;
}
void set_graph_exec_group(std::optional<SafePyObject> group) {
graph_exec_group_ = std::move(group);
}
bool get_grad_mode() const {
return grad_mode_;
}
@ -61,7 +68,12 @@ struct C10_API AutogradState {
return view_replay_enabled_;
}
const std::optional<SafePyObject>& get_graph_exec_group() const {
return graph_exec_group_;
}
private:
std::optional<SafePyObject> graph_exec_group_;
bool grad_mode_ : 1;
bool inference_mode_ : 1;
bool fw_grad_mode_ : 1;

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest
import torch
import torch.distributed as dist
@ -372,7 +371,6 @@ class DTensorExportTest(TestCase):
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
# is producing a joint graph with backward region missing
@unittest.expectedFailure
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)

View File

@ -13,12 +13,12 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@ -275,7 +275,14 @@ class ActivationCheckpointingViaTagsTests(
run(export_compiler)
def test_tags_function(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -291,11 +298,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_via_global_checkpoint(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -310,17 +328,28 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_with_kwargs(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_with_kwargs(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
gn, torch.sin(x), y, use_reentrant=False
)
x = torch.randn(4, 4, device=device, requires_grad=True)
@ -330,11 +359,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_sequential_layers(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_sequential_layers(self, device, partition_fn):
def gn(x):
x = x.cos()
for _ in range(3):
@ -355,11 +395,22 @@ class ActivationCheckpointingViaTagsTests(
freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_multiple_checkpoints(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_multiple_checkpoints(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -377,11 +428,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_module(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_module(self, device, partition_fn):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -405,11 +467,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_decomps(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_decomps(self, device, partition_fn):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self) -> None:
@ -437,6 +510,7 @@ class ActivationCheckpointingViaTagsTests(
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
@ -696,7 +770,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_recompute(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
@ -717,9 +798,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
),
)
def _test(context_fn, bw_compiler):
def _test(context_fn, bw_compiler, partition_fn):
def gn(x):
return torch.sigmoid(torch.matmul(x, x))
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
@ -733,14 +814,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
fw_compiler = functools.partial(
count_ops,
freq=1,
freq=2,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@ -748,17 +829,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
freq=6, # per fwd matmul: 1 recompute and 2 bwd mm ops; 2 fwd matmuls, so 2 + 2 * 2 = 6
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
_test(
context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=2, # 2 bwd mm ops per fwd matmul
freq=4, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
def test_sac_with_partial_context_fn(self):
@ -795,7 +878,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm(
self, device, partition_fn
):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -835,15 +927,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device
self, device, partition_fn
):
def selective_checkpointing_context_fn():
no_recompute_list = [
@ -883,7 +982,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
disable_functionalization=True,
)
self._validate(fn, backend, x, y)
@ -891,7 +990,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_triton_kernel(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
# Copy of the above test, but make sure that having a triton kernel in the
# region does not error.
def add_one(x):
@ -951,14 +1057,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_tensor_subclass(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1001,14 +1114,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_custom_rule(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1066,14 +1186,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
@ -1112,14 +1239,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_outplace_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1157,14 +1291,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_list_ops(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
def selective_checkpointing_context_fn():
# recompute everything
no_recompute_list = []
@ -1200,7 +1341,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1211,7 +1352,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
"requires TorchDispatchMode + torch.compile work to complete"
)
@requires_cuda_and_triton
def test_compile_selective_checkpoint_inplace_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1251,7 +1399,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1259,7 +1407,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
@ -1306,7 +1461,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@ -1318,7 +1473,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_invalid_context(self):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
@ -1347,7 +1509,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
@ -1356,7 +1518,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_compile_selective_checkpoint_parametrization(self):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
def sac_policy():
def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
@ -1419,7 +1588,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
bw_compiler = functools.partial(
count_ops,
freqs=[
2, # 1 from mul recompute, 1 from mul backward
# 1 from mul recompute, 1 from mul backward
# w/o CSE, we have one extra mul
3 if partition_fn is default_partition else 2,
1,
],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@ -1428,7 +1599,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
model = MLPModule()

View File

@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
return grad_output * x, grad_output * x
def f(a, b):
return FwBwMutation.apply(a, b)
return FwBwMutation.apply(a, b).sin_().clone()
inps = [
torch.ones(3, 3, requires_grad=True),
@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
return (mul, add)""",
clone = torch.ops.aten.clone.default(mul)
sin_ = torch.ops.aten.sin_.default(mul); mul = None
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
return (clone_1, add, clone)""",
)
# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, tangents_1):
def forward(self, add, clone, tangents_1):
cos = torch.ops.aten.cos.default(clone); clone = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
return (mul_2, None)""",
)
def test_fw_bw_mutation_no_functionalization2(self):

View File

@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
op="call_function", target=torch.ops.aten.mm.default
)
self.assertEqual(len(mm_nodes), 4)
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)

View File

@ -5222,6 +5222,7 @@ xfail_by_backend = {
"test_reentrant_with_callbacks_both_depths", # queue_callback
"test_reentrant_with_callbacks_depth_0", # queue_callback
"test_reentrant_with_callbacks_depth_1", # queue_callback
"test_checkpoint_graph_execution_group", # Attempted to call function marked as skipped
"test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_post_accumulate_grad_hook_ordering", # accuracy error

View File

@ -7364,6 +7364,62 @@ for shape in [(1,), ()]:
):
checkpoint_sequential(modules_list, 3, a)
@skipIfTorchDynamo("GraphExecGroup does not support compile")
def test_checkpoint_graph_execution_group(self):
def run(use_graph_execution_group):
counter = [0]
def fn(x):
counter[0] += 1
y = x.sin().cos()
z = y.sin().cos()
return y, z
x = torch.randn(3, 3, requires_grad=True)
y, z = checkpoint(fn, x, use_reentrant=False)
group = torch.utils.checkpoint.GraphExecGroup()
ctx = contextlib.nullcontext()
if use_graph_execution_group:
ctx = group
with ctx:
(grad_y,) = torch.autograd.grad(
z, inputs=(y,), grad_outputs=(torch.ones(3, 3),)
)
(grad_x,) = torch.autograd.grad(
y,
inputs=(x,),
grad_outputs=(grad_y,),
)
if use_graph_execution_group:
self.assertEqual(counter[0], 2)
else:
self.assertEqual(counter[0], 3)
run(use_graph_execution_group=True)
run(use_graph_execution_group=False)
# Test the case where the two backwards are not actually disjoint (using
# retain_graph=True, since otherwise autograd itself will catch this)
def fn(x):
return x.sin().cos()
x = torch.randn(3, 3, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False)
with torch.utils.checkpoint.GraphExecGroup():
# Under this context, we will enforce that two backward calls are disjoint
# even if retain_graph=True.
out.sum().backward(retain_graph=True)
with self.assertRaisesRegex(
RuntimeError, "Performing two backward calls that overlap"
):
out.sum().backward()
def test_checkpoint_detects_non_determinism(self):
def save_3_tensors(x):
out = x.sin().exp()

View File

@ -69,6 +69,7 @@ from torch.types import (
Storage,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import GraphExecGroup
# This module is defined in torch/csrc/Module.cpp
@ -1491,6 +1492,8 @@ def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
def _set_graph_exec_group(group: GraphExecGroup | None) -> None: ...
def _get_graph_exec_group() -> GraphExecGroup | None: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
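
Not part of the diff: a small sketch of the contract implied by the two new stubs above. The group object is stashed in autograd thread-local state and read back unchanged (`GraphExecGroup` is the class added to torch/utils/checkpoint.py later in this compare):

import torch
from torch.utils.checkpoint import GraphExecGroup

group = GraphExecGroup()
assert torch._C._get_graph_exec_group() is None  # nothing set by default
torch._C._set_graph_exec_group(group)  # stash the group in autograd TLS
assert torch._C._get_graph_exec_group() is group  # the same Python object comes back
torch._C._set_graph_exec_group(None)  # clear the slot again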

View File

@ -10,6 +10,7 @@ This file contains utilities related to functionalization in AOTAutograd:
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor
@ -449,7 +450,7 @@ def was_tensor_metadata_updated(arg, new_arg):
# Returns the number of detected copy_
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]:
allowed_mutation_ops = [
torch.ops.aten.copy_.default,
torch.ops.aten.set_.source_Tensor,
@ -462,6 +463,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
# NB: It would also be nice to verify that the mutations all happen at the
# end, but we also do some administrative views after mutations so this
# isn't actually true. (TODO: Could this cause problems for Inductor?)
error = None
for n in fx_g.nodes:
if n.op == "placeholder":
placeholders.add(n)
@ -471,14 +473,18 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
# this is mostly a hack to avoid failing XLA tests.
# See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
if "set_buffer_donor_" not in str(n.args[0]):
assert n.args[0] in placeholders, (
f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
)
if n.args[0] not in placeholders:
error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
mutation_count += 1
else:
assert not n.target._schema.is_mutable, (
f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
)
if n.target._schema.is_mutable:
error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
return error, mutation_count
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
error, mutation_count = _is_functional_graph(fx_g)
assert error is None, error
return mutation_count

View File

@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
_proxy_tensor_disable_update_tensor_tracker,
get_proxy_mode,
maybe_disable_thunkify,
maybe_enable_thunkify,
)
@ -295,6 +296,10 @@ def create_joint(
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
fn, primals
)
mode = get_proxy_mode()
assert mode is not None, "Expected non-None proxy mode"
for node in mode.tracer.graph.nodes:
node.meta["partitioner_tag"] = "is_forward"
# TODO: I think this hook can also be eliminated now
if joint_fn_handle and joint_fn_handle.post_forward:

View File

@ -10,6 +10,7 @@ import operator
import os
import os.path
import re
import warnings
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, replace
@ -51,6 +52,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import _is_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@ -297,6 +299,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_is_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_forward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
@ -1021,105 +1027,132 @@ def default_partition(
Returns:
Returns the generated forward and backward Fx graph modules.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
# Respect the original placement of ops rather than rely on dataflow.
forward_nodes = []
last_node = None
for node in joint_module.graph.nodes:
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
last_node = node
assert last_node is not None
for node in joint_module.graph.nodes:
if not _is_tangent(node):
forward_nodes.append(node)
if node is last_node:
break
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
node.name for node in forward_nodes if node.op != "output"
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
if _is_functional_graph(joint_module.graph)[0] is not None:
# Fall back to the previous behavior to avoid BC-breaking changes, although we
# can eventually flip the switch to make this a hard error.
warnings.warn(
"Trying to unsafely apply AC to a non-functional graph with the "
"default partitioner. Falling back to min-cut partitioner."
)
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
saved_values = []
saved_sym_nodes = []
def is_mutated_later_in_fw(node):
if _has_tag_is_backward(node):
return False
tensor_arg_aliases = [
x
for x in node.args
if isinstance(x, fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor)
]
while len(tensor_arg_aliases) > 0:
a = tensor_arg_aliases.pop()
for u in a.users:
if not isinstance(u.target, torch._ops.OpOverload):
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
if (
# one of the args was mutated
u.target._schema.is_mutable
# and the mutation happens "later"
and order[u] > order[node]
# and the mutation happened during the forward
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
):
for idx, alias_info in enumerate(u.target._schema.arguments):
if alias_info.is_write and u.args[idx] is a:
return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
def is_tensor(node):
# This node returns a single tensor output
return (
"tensor_meta" in node.meta
and node.op == "call_function"
and isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
)
def is_multi_output(node):
return (
not is_tensor(node)
and all(user.target == operator.getitem for user in node.users)
and len(node.users) > 0
)
def is_impure(node):
# wait tensor is an "impure" op according to DCE's definition of impure
# (see is_impure in torch/fx/node.py), but it survives past
# functionalization and can be safely dup'd and reordered under the
# SPMD assumption.
return (
node.is_impure(impure_random=False)
and node.op
not in (
"placeholder",
"output",
)
and node.target is not torch.ops._c10d_functional.wait_tensor.default
)
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
elif (
"tensor_meta" not in node.meta
and node.op == "call_function"
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
):
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target is operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [
n for n in node.users if n.name not in forward_node_names
]
if "tensor_meta" in node.meta and all(
is_sym_node(n) for n in backward_usages
):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
else:
saved_values.append(node)
continue
if is_multi_output(node):
# Must come before the MUST_SAVE check below to avoid saving multi-output (tuple) nodes marked MUST_SAVE.
continue
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
saved_values.append(node)
continue
if is_impure(node):
assert not graph_has_recomputable_ops, (
"Trying to apply AC on a graph with impure op",
node,
node.target,
)
saved_values.append(node)
continue
if node.op == "call_function":
assert is_tensor(node), f"{node}"
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
continue
if not must_recompute(node):
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
return _extract_fwd_bwd_modules(
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
if static_lifetime_input_nodes is None:
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
@ -1127,6 +1160,24 @@ def default_partition(
static_lifetime_input_nodes=static_lifetime_input_nodes,
)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)
# raise all getitem ops to as early as possible
# this is helpful for memory, especially in the case of aot_eager backend
fw_module = raise_getitems(fw_module)
bw_module = raise_getitems(bw_module)
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
if len(node_info.required_bw_nodes) > 0:
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
return fw_module, bw_module
INT_INF = int(1e6)
@ -1621,7 +1672,16 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
break
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
def is_getitem_of_multi_output(node):
if node.target != operator.getitem:
return False
parent = node.args[0]
return "tensor_meta" not in parent.meta and node.op == "call_function"
def cleanup_recompute_tags(
joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in
between, we would still want to stash the tensor at the boundary of
@ -1658,6 +1718,20 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
# in forward graph outputs. With this, we can break the above circular dependency.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
elif (
"ac_graph_id" not in node.meta
and any(must_recompute(user) for user in node.users)
and not (
# Avoid saving getitem nodes which are not labeled with "ac_graph_id"
is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta
)
and is_default_partition
):
# This node is not part of the AC region, but one of its users is marked for recompute.
# This means it is an input to the AC region, so we should save it.
# For ease of landing, gate this to default partitioner only, but we should think
# about flipping the switch in general as well.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
return joint_module
@ -2765,6 +2839,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
return module
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
def min_cut_rematerialization_partition(
joint_module: fx.GraphModule,
_joint_inputs,
@ -2813,68 +2940,16 @@ def min_cut_rematerialization_partition(
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
# networkx blows up on graphs with no required backward nodes
# Since there's nothing to partition anyway, and the default partitioner can "handle"
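
Not part of the diff: a minimal sketch of what the parametrized tests above exercise. With this change, `default_partition` handles joint graphs containing recomputable (activation-checkpointed) ops itself instead of always delegating to the min-cut partitioner, so it can be passed as `partition_fn` to the `aot_autograd` backend for a checkpointed function. `make_boxed_func` is existing functorch API (not introduced here), and the identity compilers are placeholders:

import torch
import torch.utils.checkpoint
from functorch.compile import default_partition, make_boxed_func
from torch._dynamo.backends.common import aot_autograd


def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))


def fn(x, y):
    # Checkpointed region: the matmul/sigmoid are recomputed in the backward.
    return torch.utils.checkpoint.checkpoint(
        gn, torch.sin(x), y, use_reentrant=False
    )


# Identity compilers; only the choice of partitioner matters for this sketch.
backend = aot_autograd(
    fw_compiler=lambda gm, example_inputs: make_boxed_func(gm.forward),
    bw_compiler=lambda gm, example_inputs: make_boxed_func(gm.forward),
    partition_fn=default_partition,
)

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
torch.compile(fn, backend=backend)(x, y).sum().backward()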

View File

@ -1218,6 +1218,33 @@ static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) {
END_HANDLE_TH_ERRORS
}
static PyObject* set_graph_exec_group(PyObject* self, PyObject* obj) {
HANDLE_TH_ERRORS
if (obj == Py_None) {
c10::AutogradState::get_tls_state().set_graph_exec_group(std::nullopt);
} else {
Py_INCREF(obj);
c10::AutogradState::get_tls_state().set_graph_exec_group(
c10::SafePyObject(obj, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* get_graph_exec_group(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
const auto& group =
c10::AutogradState::get_tls_state().get_graph_exec_group();
if (group.has_value()) {
PyObject* obj = group->ptr(getPyInterpreter());
Py_INCREF(obj);
return obj;
} else {
Py_RETURN_NONE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (c10::InferenceMode::is_enabled()) {
@ -1598,6 +1625,8 @@ static PyMethodDef methods[] = {
castPyCFunctionWithKeywords(set_view_replay_enabled),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_set_graph_exec_group", set_graph_exec_group, METH_O, nullptr},
{"_get_graph_exec_group", get_graph_exec_group, METH_NOARGS, nullptr},
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
{"_exit_dual_level",
castPyCFunctionWithKeywords(python_exit_dual_level),

View File

@ -33,6 +33,7 @@ __all__ = [
"SelectiveCheckpointContext",
"create_selective_checkpoint_contexts",
"SAC_IGNORED_OPS",
"GraphExecGroup",
]
_DEFAULT_DETERMINISM_MODE = "default"
@ -1072,7 +1073,7 @@ class _StopRecomputationError(Exception):
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, target_frame_ref: ReferenceType, gid: int) -> None:
def __init__(self, target_frame_ref: ReferenceType, gid: Union["GraphExecGroup", int]) -> None:
def pack_hook(x):
x = x.detach() if x.requires_grad else x
target_frame = target_frame_ref()
@ -1145,10 +1146,14 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
return holder
def unpack_hook(holder):
gid = torch._C._current_graph_task_id()
if gid == -1:
# generate a temporary id if we trigger unpack outside of a backward call
gid = int(uuid.uuid4())
# First check if we're inside a GraphExecGroup context
gid: Union[GraphExecGroup, None, int] = GraphExecGroup._get_current_group()
if gid is None:
# Fallback to using the current graph task id
gid = torch._C._current_graph_task_id()
if gid == -1:
# generate a temporary id if we trigger unpack outside of a backward call
gid = int(uuid.uuid4())
if not frame.is_recomputed[gid]:
ctx = frame.input_saver.grad_fn
@ -1168,10 +1173,17 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
_internal_assert(gid in holder.handles)
if holder.handles[gid] is None:
extra = ""
if torch._C._get_graph_exec_group() is not None:
extra = (
"Performing two backward calls that overlap (i.e. require the same "
"saved activation in order to compute gradients) is not allowed while "
"under the torch.utils.checkpoint.GraphExecGroup context. "
)
raise CheckpointError(
"torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
"unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
"so only once. Otherwise please open an issue with details on your use case."
f"unpacked once. {extra}If you are calling ctx.saved_tensors in backward, make sure "
"to do so only once. Otherwise please open an issue with details on your use case."
)
_internal_assert(holder.handles[gid] in frame.recomputed[gid])
ret = frame.recomputed[gid][holder.handles[gid]]
@ -1594,6 +1606,40 @@ def _checkpoint_without_reentrant_generator(
return
class GraphExecGroup:
"""Any checkpointed regions encountered by backward under the same instance
of this context manager will trigger recompute at most once, even if
there are multiple calls to backward.
Backward calls under the same instance of this context manager must execute
over non-overlapping regions of the backward graph even if retain_graph=True.
In particular, no two backward calls may use the same saved activation for
gradient computation.
.. note::
This context manager only affects checkpoint with use_reentrant=False, and
is a no-op otherwise.
"""
def __enter__(self) -> "GraphExecGroup":
if torch._C._get_graph_exec_group() is not None:
raise RuntimeError(
"GraphExecGroup contexts cannot be nested. "
f"Already inside group {torch._C._get_graph_exec_group()}"
)
torch._C._set_graph_exec_group(self)
return self
def __exit__(self, *args: object) -> None:
torch._C._set_graph_exec_group(None)
@classmethod
def _get_current_group(cls) -> Optional["GraphExecGroup"]:
# Private API to be used by utils like AC
return torch._C._get_graph_exec_group()
# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.
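
Not part of the diff: a standalone usage sketch of the new `GraphExecGroup` context manager, mirroring `test_checkpoint_graph_execution_group` earlier in this compare. Both backward calls run under one group instance, cover disjoint regions of the backward graph, and share a single recompute of the checkpointed region:

import torch
from torch.utils.checkpoint import GraphExecGroup, checkpoint

counter = [0]


def fn(x):
    counter[0] += 1  # counts how many times the region is (re)computed
    y = x.sin().cos()
    z = y.sin().cos()
    return y, z


x = torch.randn(3, 3, requires_grad=True)
y, z = checkpoint(fn, x, use_reentrant=False)  # forward: counter[0] == 1

with GraphExecGroup():
    # Two disjoint backward calls under one group: the checkpointed region is
    # recomputed once and that recompute is shared between them.
    (grad_y,) = torch.autograd.grad(z, inputs=(y,), grad_outputs=(torch.ones(3, 3),))
    (grad_x,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=(grad_y,))

assert counter[0] == 2  # 1 forward + 1 shared recompute (3 without the group)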