Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-14 14:15:07 +08:00)
Compare commits
31 Commits
docs...ciflow/tru
| SHA1 |
|---|
| 19fc2d35d7 |
| 8a02e7fd14 |
| eea6b6f0e3 |
| b2f6928f34 |
| 09bb739aa3 |
| 983b20821d |
| dede07ca26 |
| eecc531f81 |
| ca7270992c |
| eb5947d25f |
| 1b94305bd0 |
| c1e0a9318d |
| 5c95ba1705 |
| fdbf5935c2 |
| be9493cbd8 |
| 10b2b1a8bc |
| 46eec57fbc |
| 1b9b8f52ae |
| 0d559d0c20 |
| 70714103b1 |
| 5d7e730359 |
| 363b1d2b49 |
| f278c43737 |
| 2870894809 |
| cb0bb1d8bb |
| 66990f8dea |
| a1ee245e3e |
| 9266afcde2 |
| 6913ecb72e |
| a3d40e72f2 |
| 4708491c8d |
@@ -1,6 +1,8 @@
#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <optional>

namespace c10 {

@@ -15,7 +17,8 @@ struct C10_API AutogradState {
      bool inference_mode,
      bool fw_grad_mode,
      bool multithreading_enabled)
      : grad_mode_(grad_mode),
      : graph_exec_group_(std::nullopt),
        grad_mode_(grad_mode),
        inference_mode_(inference_mode),
        fw_grad_mode_(fw_grad_mode),
        multithreading_enabled_(multithreading_enabled),
@@ -41,6 +44,10 @@ struct C10_API AutogradState {
    view_replay_enabled_ = view_replay_enabled;
  }

  void set_graph_exec_group(std::optional<SafePyObject> group) {
    graph_exec_group_ = std::move(group);
  }

  bool get_grad_mode() const {
    return grad_mode_;
  }
@@ -61,7 +68,12 @@ struct C10_API AutogradState {
    return view_replay_enabled_;
  }

  const std::optional<SafePyObject>& get_graph_exec_group() const {
    return graph_exec_group_;
  }

 private:
  std::optional<SafePyObject> graph_exec_group_;
  bool grad_mode_ : 1;
  bool inference_mode_ : 1;
  bool fw_grad_mode_ : 1;
@@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import unittest

import torch
import torch.distributed as dist
@@ -372,7 +371,6 @@ class DTensorExportTest(TestCase):

    # aot_export_joint_with_descriptors on strict-exported exported_program.module()
    # is producing a joint graph with backward region missing
    @unittest.expectedFailure
    def test_strict_export_parallelize_module_with_dtensor_input(self):
        self._run_test(strict_export_and_aot_export_joint_with_descriptors)

@@ -13,12 +13,12 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@@ -275,7 +275,14 @@ class ActivationCheckpointingViaTagsTests(

        run(export_compiler)

    def test_tags_function(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -291,11 +298,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        ) # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)

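The hunks above and below repeat one mechanical refactor across the checkpoint-tagging tests: the partitioner is no longer hard-coded, and each test takes a `partition_fn` argument supplied by `parametrize`, so every test body runs once with `min_cut_rematerialization_partition` and once with `default_partition`. A minimal, self-contained sketch of that wiring (the class name and trivial assertion are illustrative; the real tests are device-parametrized through `instantiate_device_type_tests`, while this sketch uses `instantiate_parametrized_tests` to stay device-agnostic):

```python
# Minimal sketch of the parametrization pattern used by the tests in this diff.
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class PartitionerParametrizationExample(TestCase):
    @parametrize(
        "partition_fn",
        [min_cut_rematerialization_partition, default_partition],
    )
    def test_runs_for_each_partitioner(self, partition_fn):
        # Each generated test instance receives one partitioner callable.
        self.assertTrue(callable(partition_fn))


instantiate_parametrized_tests(PartitionerParametrizationExample)

if __name__ == "__main__":
    run_tests()
```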
    @requires_cuda_and_triton
    def test_tags_function_via_global_checkpoint(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function_via_global_checkpoint(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -310,17 +328,28 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        ) # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_function_with_kwargs(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function_with_kwargs(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
                gn, torch.sin(x), y, use_reentrant=False
            )

        x = torch.randn(4, 4, device=device, requires_grad=True)
@@ -330,11 +359,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        ) # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_sequential_layers(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_sequential_layers(self, device, partition_fn):
        def gn(x):
            x = x.cos()
            for _ in range(3):
@@ -355,11 +395,22 @@ class ActivationCheckpointingViaTagsTests(
            freqs=[2, 18],
            ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
        ) # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    def test_tags_multiple_checkpoints(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_multiple_checkpoints(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -377,11 +428,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=6, op=torch.ops.aten.mm.default
        ) # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_module(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_module(self, device, partition_fn):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
@@ -405,11 +467,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    def test_tags_decomps(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_decomps(self, device, partition_fn):
        # Ensures that tags are passed on through decompositions as well
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
@@ -437,6 +510,7 @@ class ActivationCheckpointingViaTagsTests(
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
@@ -696,7 +770,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_recompute(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
        def context_fn_must_recompute_mm():
            must_recompute_list = [
                torch.ops.aten.mm.default,
@@ -717,9 +798,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
                ),
            )

        def _test(context_fn, bw_compiler):
        def _test(context_fn, bw_compiler, partition_fn):
            def gn(x):
                return torch.sigmoid(torch.matmul(x, x))
                return torch.cos(torch.sin(torch.matmul(x, x) @ x))

            def fn(x):
                return torch.utils.checkpoint.checkpoint(
@@ -733,14 +814,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

            fw_compiler = functools.partial(
                count_ops,
                freq=1,
                freq=2,
                op=torch.ops.aten.mm.default,
            )

            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=min_cut_rematerialization_partition,
                partition_fn=partition_fn,
            )
            self._validate(fn, backend, x)

@@ -748,17 +829,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
            context_fn=context_fn_must_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
                freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
                op=torch.ops.aten.mm.default,
            ),
            partition_fn=partition_fn,
        )
        _test(
            context_fn=context_fn_no_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=2, # 2 bwd mm ops per fwd matmul
                freq=4, # 2 bwd mm ops per fwd matmul
                op=torch.ops.aten.mm.default,
            ),
            partition_fn=partition_fn,
        )

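The `freq` bumps in the two hunks above follow from the new `gn`, which contains two matmuls instead of one, so the counted `aten.mm` occurrences double in both the forward and backward graphs. The checks run through the test file's `count_ops` compiler callback; a rough, self-contained sketch of that counting-compiler idea (the helper name `count_aten_ops` and its exact signature are illustrative assumptions, not the real `count_ops`):

```python
# Hedged sketch: a compiler callback that counts occurrences of a given aten op
# in the FX graph handed to it, in the spirit of the test file's count_ops.
import torch


def count_aten_ops(gm: torch.fx.GraphModule, example_inputs, *, op, expected_freq):
    actual = sum(1 for node in gm.graph.nodes if node.target is op)
    assert actual == expected_freq, f"expected {expected_freq} x {op}, found {actual}"
    return gm.forward  # return a callable so aot_autograd can run the graph


# Usage idea: functools.partial(count_aten_ops, op=torch.ops.aten.mm.default,
# expected_freq=2) plugged in as fw_compiler or bw_compiler of aot_autograd.
```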
    def test_sac_with_partial_context_fn(self):
@@ -795,7 +878,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_not_recompute_gemm(
        self, device, partition_fn
    ):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -835,15 +927,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
        self, device
        self, device, partition_fn
    ):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
@@ -883,7 +982,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
            disable_functionalization=True,
        )
        self._validate(fn, backend, x, y)
@@ -891,7 +990,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_triton_kernel(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
        # Copy of the above test, but make sure that having a triton kernel in the
        # region does not error.
        def add_one(x):
@@ -951,14 +1057,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_tensor_subclass(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1001,14 +1114,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_custom_rule(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
        def _get_custom_policy(meta):
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1066,14 +1186,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
        def selective_checkpointing_context_fn(no_recompute_list):
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
@@ -1112,14 +1239,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_outplace_op(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1157,14 +1291,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_list_ops(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
        def selective_checkpointing_context_fn():
            # recompute everything
            no_recompute_list = []
@@ -1200,7 +1341,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@@ -1211,7 +1352,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        "requires TorchDispatchMode + torch.compile work to complete"
    )
    @requires_cuda_and_triton
    def test_compile_selective_checkpoint_inplace_op(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1251,7 +1399,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@@ -1259,7 +1407,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @torch._inductor.config.patch(fallback_random=True)
    def test_compile_selective_checkpoint_random_op(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
        for preserve_rng_state in [True, False]:

            def selective_checkpointing_context_fn():
@@ -1306,7 +1461,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=min_cut_rematerialization_partition,
                partition_fn=partition_fn,
            )

            # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@@ -1318,7 +1473,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_invalid_context(self):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y)) * y

@@ -1347,7 +1509,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        with self.assertRaisesRegex(
            Exception, "must generate a tuple of two `TorchDispatchMode`s"
@@ -1356,7 +1518,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_compile_selective_checkpoint_parametrization(self):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_parametrization(self, partition_fn):
        def sac_policy():
            def _recomp_policy():
                def _custom_policy(ctx, func, *args, **kwargs):
@@ -1419,7 +1588,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        bw_compiler = functools.partial(
            count_ops,
            freqs=[
                2, # 1 from mul recompute, 1 from mul backward
                # 1 from mul recompute, 1 from mul backward
                # w/o CSE, we have one extra mul
                3 if partition_fn is default_partition else 2,
                1,
            ],
            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@@ -1428,7 +1599,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )

        model = MLPModule()

@@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
                return grad_output * x, grad_output * x

        def f(a, b):
            return FwBwMutation.apply(a, b)
            return FwBwMutation.apply(a, b).sin_().clone()

        inps = [
            torch.ones(3, 3, requires_grad=True),
@@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
    add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
    _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
    mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
    return (mul, add)""",
    clone = torch.ops.aten.clone.default(mul)
    sin_ = torch.ops.aten.sin_.default(mul); mul = None
    clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
    return (clone_1, add, clone)""",
        )

        # important bit: there is 1 mutation in the bw
        self.assertExpectedInline(
            bw_graph[0].code.strip(),
            """\
def forward(self, add, tangents_1):
def forward(self, add, clone, tangents_1):
    cos = torch.ops.aten.cos.default(clone); clone = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
    _foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
    return (mul_1, None)""",
    mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
    return (mul_2, None)""",
        )

    def test_fw_bw_mutation_no_functionalization2(self):

@@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
            op="call_function", target=torch.ops.aten.mm.default
        )
        self.assertEqual(len(mm_nodes), 4)
        self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
        self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
        self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
        self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
        self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
        self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
        self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)

@@ -5222,6 +5222,7 @@ xfail_by_backend = {
    "test_reentrant_with_callbacks_both_depths", # queue_callback
    "test_reentrant_with_callbacks_depth_0", # queue_callback
    "test_reentrant_with_callbacks_depth_1", # queue_callback
    "test_checkpoint_graph_execution_group", # Attempted to call function marked as skipped
    "test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
    "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
    "test_post_accumulate_grad_hook_ordering", # accuracy error

@@ -7364,6 +7364,62 @@ for shape in [(1,), ()]:
        ):
            checkpoint_sequential(modules_list, 3, a)

    @skipIfTorchDynamo("GraphExecGroup does not support compile")
    def test_checkpoint_graph_execution_group(self):
        def run(use_graph_execution_group):
            counter = [0]

            def fn(x):
                counter[0] += 1
                y = x.sin().cos()
                z = y.sin().cos()
                return y, z

            x = torch.randn(3, 3, requires_grad=True)

            y, z = checkpoint(fn, x, use_reentrant=False)

            group = torch.utils.checkpoint.GraphExecGroup()

            ctx = contextlib.nullcontext()
            if use_graph_execution_group:
                ctx = group

            with ctx:
                (grad_y,) = torch.autograd.grad(
                    z, inputs=(y,), grad_outputs=(torch.ones(3, 3),)
                )

                (grad_x,) = torch.autograd.grad(
                    y,
                    inputs=(x,),
                    grad_outputs=(grad_y,),
                )

            if use_graph_execution_group:
                self.assertEqual(counter[0], 2)
            else:
                self.assertEqual(counter[0], 3)

        run(use_graph_execution_group=True)
        run(use_graph_execution_group=False)

        # Test the not actually disjoint case (using retain_graph=True since
        # otherwise autograd itself will catch this)
        def fn(x):
            return x.sin().cos()

        x = torch.randn(3, 3, requires_grad=True)
        out = checkpoint(fn, x, use_reentrant=False)
        with torch.utils.checkpoint.GraphExecGroup():
            # Under this context, we will enforce that two backward are disjoint
            # even if retain_graph=True.
            out.sum().backward(retain_graph=True)
            with self.assertRaisesRegex(
                RuntimeError, "Performing two backward calls that overlap"
            ):
                out.sum().backward()

    def test_checkpoint_detects_non_determinism(self):
        def save_3_tensors(x):
            out = x.sin().exp()

@@ -69,6 +69,7 @@ from torch.types import (
    Storage,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import GraphExecGroup

# This module is defined in torch/csrc/Module.cpp

@@ -1491,6 +1492,8 @@ def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
def _set_graph_exec_group(group: GraphExecGroup | None) -> None: ...
def _get_graph_exec_group() -> GraphExecGroup | None: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...

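The two new stubs above expose the thread-local `graph_exec_group` slot that the `AutogradState` change at the top of this diff adds on the C++ side. A hedged round-trip sketch of those private bindings (internal API that assumes this patch is applied; in normal use the public `torch.utils.checkpoint.GraphExecGroup` context manager further down manages this state for you):

```python
# Hedged sketch: exercising the private TLS bindings directly.
import torch
from torch.utils.checkpoint import GraphExecGroup

assert torch._C._get_graph_exec_group() is None  # default TLS state is empty

group = GraphExecGroup()
torch._C._set_graph_exec_group(group)       # stash the Python object in autograd TLS
assert torch._C._get_graph_exec_group() is group

torch._C._set_graph_exec_group(None)        # clearing resets the optional slot
assert torch._C._get_graph_exec_group() is None
```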
@@ -10,6 +10,7 @@ This file contains utilities related to functionalization in AOTAutograd:
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor
@@ -449,7 +450,7 @@ def was_tensor_metadata_updated(arg, new_arg):


# Returns the number of detected copy_
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]:
    allowed_mutation_ops = [
        torch.ops.aten.copy_.default,
        torch.ops.aten.set_.source_Tensor,
@@ -462,6 +463,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
    # NB: It would also be nice to verify that the mutations all happen at the
    # end, but we also do some administrative views after mutations so this
    # isn't actually true. (TODO: Could this cause problems for Inductor?)
    error = None
    for n in fx_g.nodes:
        if n.op == "placeholder":
            placeholders.add(n)
@@ -471,14 +473,18 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
            # this is mostly a hack to avoid failing XLA tests.
            # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
            if "set_buffer_donor_" not in str(n.args[0]):
                assert n.args[0] in placeholders, (
                    f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
                )
                if n.args[0] not in placeholders:
                    error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
            mutation_count += 1
        else:
            assert not n.target._schema.is_mutable, (
                f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
            )
            if n.target._schema.is_mutable:
                error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
    return error, mutation_count


def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
    error, mutation_count = _is_functional_graph(fx_g)
    assert error is None, error
    return mutation_count


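The refactor above keeps `assert_functional_graph` as the raising entry point while factoring the scan into `_is_functional_graph`, which reports an `(error, mutation_count)` pair instead of asserting; the default partitioner later in this diff uses the non-raising form to fall back gracefully. A small usage sketch on a trivially functional graph (the import path `torch._functorch._aot_autograd.functional_utils` is an assumption inferred from the hunk header, and both helpers are private):

```python
# Hedged sketch: probing a toy FX graph with the functional-graph helpers.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch._aot_autograd.functional_utils import (  # assumed module path
    _is_functional_graph,
    assert_functional_graph,
)


def f(x):
    return x.sin().cos()


gm = make_fx(f)(torch.randn(3))

error, mutation_count = _is_functional_graph(gm.graph)  # non-raising form
assert error is None and mutation_count == 0  # no copy_/set_ mutations here

assert assert_functional_graph(gm.graph) == 0  # raising form, same count
```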
@@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
    _proxy_tensor_disable_update_tensor_tracker,
    get_proxy_mode,
    maybe_disable_thunkify,
    maybe_enable_thunkify,
)
@@ -295,6 +296,10 @@ def create_joint(
        (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
            fn, primals
        )
        mode = get_proxy_mode()
        assert mode is not None, "Expected non-None proxy mode"
        for node in mode.tracer.graph.nodes:
            node.meta["partitioner_tag"] = "is_forward"

        # TODO: I think this hook can also be eliminated now
        if joint_fn_handle and joint_fn_handle.post_forward:

@@ -10,6 +10,7 @@ import operator
import os
import os.path
import re
import warnings
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, replace
@@ -51,6 +52,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import _is_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@@ -297,6 +299,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "is_backward"


def _has_tag_is_forward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "is_forward"


def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "must_be_in_forward"

@@ -1021,105 +1027,132 @@ def default_partition(
    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    if has_recomputable_ops(joint_module):
        return min_cut_rematerialization_partition(
            joint_module,
            _joint_inputs,
            num_fwd_outputs=num_fwd_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
        )
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
    inputs = primal_inputs + fwd_seed_offset_inputs
    fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
        _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    )
    forward_only_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
    )
    # Respect the original placement of ops rather than rely on dataflow.
    forward_nodes = []
    last_node = None
    for node in joint_module.graph.nodes:
        if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
            last_node = node
    assert last_node is not None
    for node in joint_module.graph.nodes:
        if not _is_tangent(node):
            forward_nodes.append(node)
        if node is last_node:
            break
    forward_node_names = OrderedSet(
        node.name for node in forward_only_graph.nodes if node.op != "output"
        node.name for node in forward_nodes if node.op != "output"
    )
    order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
    graph_has_recomputable_ops = has_recomputable_ops(joint_module)
    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
    if graph_has_recomputable_ops:
        if _is_functional_graph(joint_module.graph)[0] is not None:
            # Fall-back to previous behavior to avoid bc-breaking, although can
            # eventually flip the switch to make this a hard error.
            warnings.warn(
                "Trying to unsafely apply AC to a non-functional graph with the "
                "default partitioner. Falling back to min-cut partitioner."
            )
            return min_cut_rematerialization_partition(
                joint_module,
                _joint_inputs,
                num_fwd_outputs=num_fwd_outputs,
                static_lifetime_input_indices=static_lifetime_input_indices,
            )

        joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)

        if not config.unsafe_allow_optimization_of_collectives:
            force_save_collectives(joint_module)

        force_save_bw_mutation_src(joint_module)

    if static_lifetime_input_indices is None:
        static_lifetime_input_indices = []
    node_info = classify_nodes(
        joint_module, static_lifetime_input_indices, num_fwd_outputs
    )

    saved_values = []
    saved_sym_nodes = []

    def is_mutated_later_in_fw(node):
        if _has_tag_is_backward(node):
            return False
        tensor_arg_aliases = [
            x
            for x in node.args
            if isinstance(x, fx.Node)
            and "val" in x.meta
            and isinstance(x.meta["val"], torch.Tensor)
        ]
        while len(tensor_arg_aliases) > 0:
            a = tensor_arg_aliases.pop()
            for u in a.users:
                if not isinstance(u.target, torch._ops.OpOverload):
                    continue
                # If we witness a mutation on our node later, and that mutation is not "must be in backward",
                # then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
                if (
                    # one of the args was mutated
                    u.target._schema.is_mutable
                    # and the mutation happens "later"
                    and order[u] > order[node]
                    # and the mutation happened during the forward
                    and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
                ):
                    for idx, alias_info in enumerate(u.target._schema.arguments):
                        if alias_info.is_write and u.args[idx] is a:
                            return True
                elif u.target.is_view:
                    tensor_arg_aliases.append(u)
        return False

    def is_tensor(node):
        # This node returns a single tensor output
        return (
            "tensor_meta" in node.meta
            and node.op == "call_function"
            and isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
        )

    def is_multi_output(node):
        return (
            not is_tensor(node)
            and all(user.target == operator.getitem for user in node.users)
            and len(node.users) > 0
        )

    def is_impure(node):
        # wait tensor is an "impure" op according to DCE's definition of impure
        # (see is_impure in torch/fx/node.py), but it survives past
        # functionalization and can be safely dup'd and reordered under the
        # assumption SPMD.
        return (
            node.is_impure(impure_random=False)
            and node.op
            not in (
                "placeholder",
                "output",
            )
            and node.target is not torch.ops._c10d_functional.wait_tensor.default
        )

    for node in joint_module.graph.nodes:
        if node.name not in forward_node_names:
            # if a node isn't "required" to be in the forward, but any of its arguments
            # are later mutated in the forward, then it must have been run in the forward
            # (if not, and the node's arg was saved for backward, we would have mutated a saved value)
            # NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
            if is_mutated_later_in_fw(node):
                saved_values.append(node)
            continue
        if is_sym_node(node):
            # Symints must be kept separate from tensors so that PythonFunction only calls
            # save_for_backward on tensors and stashes symints in autograd .ctx
            saved_sym_nodes.append(node)
        elif (
            "tensor_meta" not in node.meta
            and node.op == "call_function"
            and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
        ):
            # Since we can't save tuple of tensor values, we need to flatten out what we're saving
            users = node.users
            assert all(user.target is operator.getitem for user in users)
            saved_values.extend(users)
        else:
            backward_usages = [
                n for n in node.users if n.name not in forward_node_names
            ]
            if "tensor_meta" in node.meta and all(
                is_sym_node(n) for n in backward_usages
            ):
                # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                # and not the actual tensor data,
                # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                #
                # Note that saving the tensor could also cause compilation problems:
                # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                # then we would be obligated to clone the input before saving it to appease autograd.
                # (This is how we originally found this bug).
                saved_sym_nodes.extend(backward_usages)
            else:
                saved_values.append(node)
            continue
        if is_multi_output(node):
            # Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE.
            continue
        if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
            saved_values.append(node)
            continue
        if is_impure(node):
            assert not graph_has_recomputable_ops, (
                "Trying to apply AC on a graph with impure op",
                node,
                node.target,
            )
            saved_values.append(node)
            continue
        if node.op == "call_function":
            assert is_tensor(node), f"{node}"
            backward_usages = [n for n in node.users if n.name not in forward_node_names]
            if all(is_sym_node(n) for n in backward_usages):
                # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                # and not the actual tensor data,
                # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                #
                # Note that saving the tensor could also cause compilation problems:
                # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                # then we would be obligated to clone the input before saving it to appease autograd.
                # (This is how we originally found this bug).
                saved_sym_nodes.extend(backward_usages)
                continue
            if not must_recompute(node):
                saved_values.append(node)

    saved_values = list(dict.fromkeys(saved_values).keys())
    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())

    return _extract_fwd_bwd_modules(
    if config._sync_decision_cross_ranks:
        saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)

    if static_lifetime_input_nodes is None:
        static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
    fw_module, bw_module = _extract_fwd_bwd_modules(
        joint_module,
        saved_values,
        saved_sym_nodes=saved_sym_nodes,
@@ -1127,6 +1160,24 @@ def default_partition(
        static_lifetime_input_nodes=static_lifetime_input_nodes,
    )

    if graph_has_recomputable_ops:
        if graph_has_recomputable_rng_ops:
            fw_module, bw_module = functionalize_rng_ops(
                joint_module, fw_module, bw_module, len(saved_sym_nodes)
            )
        bw_module = reordering_to_mimic_autograd_engine(bw_module)

        # raise all getitem ops to as early as possible
        # this is helpful for memory, especially in the case of aot_eager backend
        fw_module = raise_getitems(fw_module)
        bw_module = raise_getitems(bw_module)

        fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
        if len(node_info.required_bw_nodes) > 0:
            bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)

    return fw_module, bw_module


INT_INF = int(1e6)

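With the rewrite above, `default_partition` no longer delegates to the min-cut partitioner whenever the joint graph carries recompute (activation-checkpoint) tags; it honors those tags itself and only falls back for non-functional graphs. A hedged end-to-end sketch of driving it through `aot_autograd`, mirroring the parametrized tests earlier in this diff (the pass-through compilers, shapes, and CPU device are illustrative):

```python
# Hedged sketch: compiling a checkpointed function with the default partitioner.
import torch
from functorch.compile import default_partition, make_boxed_func
from torch._dynamo.backends.common import aot_autograd


def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))


def fn(x, y):
    return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y, use_reentrant=False)


def noop_compiler(gm, example_inputs):
    # Pass-through compiler; the interesting work happens in partition_fn.
    return make_boxed_func(gm.forward)


backend = aot_autograd(
    fw_compiler=noop_compiler,
    bw_compiler=noop_compiler,
    partition_fn=default_partition,
)

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
torch.compile(fn, backend=backend)(x, y).sum().backward()
```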
@@ -1621,7 +1672,16 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
                break


def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
def is_getitem_of_multi_output(node):
    if node.target != operator.getitem:
        return False
    parent = node.args[0]
    return "tensor_meta" not in parent.meta and node.op == "call_function"


def cleanup_recompute_tags(
    joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
    """
    If there are two consecutive checkpointed blocks with no operator in
    between, we would still want to stash the tensor at the boundary of
@@ -1658,6 +1718,20 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
                # Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
                # in forward graph outputs. With this, we can break the above circular dependency.
                node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
            elif (
                "ac_graph_id" not in node.meta
                and any(must_recompute(user) for user in node.users)
                and not (
                    # Avoid saving getitem nodes which are not labeled with "ac_graph_id"
                    is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta
                )
                and is_default_partition
            ):
                # This node is not part of the AC region and a user is marked as recompute.
                # This means it's an input to the AC region and we should save it.
                # For ease of landing, gate this to default partitioner only, but we should think
                # about flipping the switch in general as well.
                node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
    return joint_module


@@ -2765,6 +2839,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
    return module


def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
    name_to_node = get_name_to_node(joint_module.graph)
    required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
    for node in joint_module.graph.nodes:
        if node.op == "placeholder" and "tangents" in node.target:
            required_bw_nodes.add(node)
        elif _must_be_in_backward(node):
            required_bw_nodes.add(node)

        if node in required_bw_nodes:
            required_bw_nodes.update(node.users)

    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
    inputs = primal_inputs + fwd_seed_offset_inputs
    fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
        _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    )
    required_bw_nodes.update(
        o for o in bwd_outputs if o is not None and o.op != "output"
    )
    forward_only_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
    )
    required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
        name_to_node[node.name]
        for node in forward_only_graph.nodes
        if node.op != "output"
    )
    unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
        node
        for node in joint_module.graph.nodes
        if node not in required_fw_nodes and node not in required_bw_nodes
    )
    static_lifetime_input_nodes = OrderedSet(
        p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
    )
    fw_cnt = 0
    fw_order = {}
    for node in joint_module.graph.nodes:
        if node in required_fw_nodes:
            fw_order[node] = fw_cnt
            fw_cnt += 1
    return NodeInfo(
        inputs,
        required_fw_nodes,
        required_bw_nodes,
        unclaimed_nodes,
        fw_order,
        static_lifetime_input_nodes,
    )


def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule,
    _joint_inputs,
@@ -2813,68 +2940,16 @@ def min_cut_rematerialization_partition(
    graph_has_recomputable_ops = has_recomputable_ops(joint_module)
    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
    if graph_has_recomputable_ops:
        joint_module = cleanup_recompute_tags(joint_module)
        joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
        if not config.unsafe_allow_optimization_of_collectives:
            force_save_collectives(joint_module)
        force_save_bw_mutation_src(joint_module)

    def classify_nodes(joint_module, static_lifetime_input_indices):
        name_to_node = get_name_to_node(joint_module.graph)
        required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
        for node in joint_module.graph.nodes:
            if node.op == "placeholder" and "tangents" in node.target:
                required_bw_nodes.add(node)
            elif _must_be_in_backward(node):
                required_bw_nodes.add(node)

            if node in required_bw_nodes:
                required_bw_nodes.update(node.users)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_seed_offset_inputs = list(
            filter(_is_fwd_seed_offset, joint_module.graph.nodes)
        )
        inputs = primal_inputs + fwd_seed_offset_inputs
        fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
            _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
        )
        required_bw_nodes.update(
            o for o in bwd_outputs if o is not None and o.op != "output"
        )
        forward_only_graph = _extract_graph_with_inputs_outputs(
            joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
        )
        required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
            name_to_node[node.name]
            for node in forward_only_graph.nodes
            if node.op != "output"
        )
        unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
            node
            for node in joint_module.graph.nodes
            if node not in required_fw_nodes and node not in required_bw_nodes
        )
        static_lifetime_input_nodes = OrderedSet(
            p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
        )
        fw_cnt = 0
        fw_order = {}
        for node in joint_module.graph.nodes:
            if node in required_fw_nodes:
                fw_order[node] = fw_cnt
                fw_cnt += 1
        return NodeInfo(
            inputs,
            required_fw_nodes,
            required_bw_nodes,
            unclaimed_nodes,
            fw_order,
            static_lifetime_input_nodes,
        )

    if static_lifetime_input_indices is None:
        static_lifetime_input_indices = []
    node_info = classify_nodes(joint_module, static_lifetime_input_indices)
    node_info = classify_nodes(
        joint_module, static_lifetime_input_indices, num_fwd_outputs
    )

    # networkx blows up on graphs with no required backward nodes
    # Since there's nothing to partition anyway, and the default partitioner can "handle"

@@ -1218,6 +1218,33 @@ static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) {
  END_HANDLE_TH_ERRORS
}

static PyObject* set_graph_exec_group(PyObject* self, PyObject* obj) {
  HANDLE_TH_ERRORS
  if (obj == Py_None) {
    c10::AutogradState::get_tls_state().set_graph_exec_group(std::nullopt);
  } else {
    Py_INCREF(obj);
    c10::AutogradState::get_tls_state().set_graph_exec_group(
        c10::SafePyObject(obj, getPyInterpreter()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* get_graph_exec_group(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  const auto& group =
      c10::AutogradState::get_tls_state().get_graph_exec_group();
  if (group.has_value()) {
    PyObject* obj = group->ptr(getPyInterpreter());
    Py_INCREF(obj);
    return obj;
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (c10::InferenceMode::is_enabled()) {
@@ -1598,6 +1625,8 @@ static PyMethodDef methods[] = {
     castPyCFunctionWithKeywords(set_view_replay_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_set_graph_exec_group", set_graph_exec_group, METH_O, nullptr},
    {"_get_graph_exec_group", get_graph_exec_group, METH_NOARGS, nullptr},
    {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
    {"_exit_dual_level",
     castPyCFunctionWithKeywords(python_exit_dual_level),

@@ -33,6 +33,7 @@ __all__ = [
    "SelectiveCheckpointContext",
    "create_selective_checkpoint_contexts",
    "SAC_IGNORED_OPS",
    "GraphExecGroup",
]

_DEFAULT_DETERMINISM_MODE = "default"
@@ -1072,7 +1073,7 @@ class _StopRecomputationError(Exception):


class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, target_frame_ref: ReferenceType, gid: int) -> None:
    def __init__(self, target_frame_ref: ReferenceType, gid: Union["GraphExecGroup", int]) -> None:
        def pack_hook(x):
            x = x.detach() if x.requires_grad else x
            target_frame = target_frame_ref()
@@ -1145,10 +1146,14 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
            return holder

        def unpack_hook(holder):
            gid = torch._C._current_graph_task_id()
            if gid == -1:
                # generate a temporary id if we trigger unpack outside of a backward call
                gid = int(uuid.uuid4())
            # First check if we're inside a GraphExecGroup context
            gid: Union[GraphExecGroup, None, int] = GraphExecGroup._get_current_group()
            if gid is None:
                # Fallback to using the current graph task id
                gid = torch._C._current_graph_task_id()
                if gid == -1:
                    # generate a temporary id if we trigger unpack outside of a backward call
                    gid = int(uuid.uuid4())

            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
@@ -1168,10 +1173,17 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
            _internal_assert(gid in holder.handles)

            if holder.handles[gid] is None:
                extra = ""
                if torch._C._get_graph_exec_group() is not None:
                    extra = (
                        "Performing two backward calls that overlap (i.e. require the same "
                        "saved activation in order to compute gradients) is not allowed while "
                        "under the torch.utils.checkpoint.GraphExecGroup context. "
                    )
                raise CheckpointError(
                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
                    "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
                    "so only once. Otherwise please open an issue with details on your use case."
                    f"unpacked once. {extra}If you are calling ctx.saved_tensors in backward, make sure "
                    "to do so only once. Otherwise please open an issue with details on your use case."
                )
            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
            ret = frame.recomputed[gid][holder.handles[gid]]
@@ -1594,6 +1606,40 @@ def _checkpoint_without_reentrant_generator(

    return


class GraphExecGroup:
    """Any checkpointed regions encountered by backward under the same instance
    of this context manager will trigger recompute at most once, even if
    there are multiple calls to backward.

    Backward calls under the same instance of this context manager must execute
    over non-overlapping regions of the backward graph even if retain_graph=True.
    In particular, any two backward calls cannot use the same saved activation for
    gradient computation.

    .. note::
        This context manager only affects checkpoint with use_reentrant=False, and
        is a no-op otherwise.
    """

    def __enter__(self) -> "GraphExecGroup":
        if torch._C._get_graph_exec_group() is not None:
            raise RuntimeError(
                "GraphExecGroup contexts cannot be nested. "
                f"Already inside group {torch._C._get_graph_exec_group()}"
            )
        torch._C._set_graph_exec_group(self)
        return self

    def __exit__(self, *args: object) -> None:
        torch._C._set_graph_exec_group(None)

    @classmethod
    def _get_current_group(cls) -> Optional["GraphExecGroup"]:
        # Private API to be used by utils like AC
        return torch._C._get_graph_exec_group()


# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.

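Condensed from `test_checkpoint_graph_execution_group` earlier in this diff, a usage sketch of the new context manager (non-reentrant checkpoint only; shapes are illustrative and the assertion assumes this patch is applied):

```python
# Two disjoint backward calls over one checkpointed region: inside
# GraphExecGroup the region is recomputed only once and shared by both.
import torch
from torch.utils.checkpoint import GraphExecGroup, checkpoint

counter = [0]  # counts how many times fn's forward body runs


def fn(x):
    counter[0] += 1
    y = x.sin().cos()
    z = y.sin().cos()
    return y, z


x = torch.randn(3, 3, requires_grad=True)
y, z = checkpoint(fn, x, use_reentrant=False)  # original forward: counter == 1

with GraphExecGroup():
    # Backward over z -> y, then over y -> x: disjoint slices of the same graph.
    (grad_y,) = torch.autograd.grad(z, inputs=(y,), grad_outputs=(torch.ones(3, 3),))
    (grad_x,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=(grad_y,))

assert counter[0] == 2  # one shared recompute; without the group it would be 3
```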