Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-16 23:44:53 +08:00)

Compare commits: ciflow/vll ... whc/pp_fix (5 commits)

| Author | SHA1 | Date |
|---|---|---|
| | aeca99aecc | |
| | 8b6f8966ce | |
| | 76379a71ac | |
| | 85a07dd185 | |
| | 9650b7407b | |
@@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import unittest

import torch
import torch.distributed as dist
@@ -357,7 +356,6 @@ class DTensorExportTest(TestCase):

    # aot_export_joint_with_descriptors on strict-exported exported_program.module()
    # is producing a joint graph with backward region missing
    @unittest.expectedFailure
    def test_strict_export_parallelize_module_with_dtensor_input(self):
        self._run_test(strict_export_and_aot_export_joint_with_descriptors)

@@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
@@ -24,7 +24,7 @@ from torch._dynamo.testing import (
)
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@@ -281,7 +281,14 @@ class ActivationCheckpointingViaTagsTests(

        run(export_compiler)

    def test_tags_function(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -297,11 +304,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
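The same change repeats for the remaining tests in this file: each test is parametrized over the partitioner and forwards `partition_fn` into the `aot_autograd` backend. Below is a minimal, self-contained sketch of that pattern outside the test harness; the `report_mm` helper, tensor shapes, and printed labels are illustrative and not part of the diff.

```python
import functools
import torch
import torch.utils.checkpoint
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd


def report_mm(gm, example_inputs, *, label):
    # Illustrative compiler hook: count aten.mm nodes in the partitioned graph, then run it as-is.
    n = sum(1 for node in gm.graph.nodes if node.target is torch.ops.aten.mm.default)
    print(f"{label}: {n} mm node(s)")
    return gm


def run_with_partitioner(partition_fn):
    def gn(x, y):
        return torch.sigmoid(torch.matmul(x, y))

    def fn(x, y):
        return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y, use_reentrant=False)

    backend = aot_autograd(
        fw_compiler=functools.partial(report_mm, label="fwd"),
        bw_compiler=functools.partial(report_mm, label="bwd"),  # recomputed mms show up here
        partition_fn=partition_fn,  # the knob the parametrized tests sweep
    )
    x = torch.randn(4, 4, requires_grad=True)
    y = torch.randn(4, 4, requires_grad=True)
    torch.compile(fn, backend=backend)(x, y).sum().backward()


for pf in (min_cut_rematerialization_partition, default_partition):
    run_with_partitioner(pf)
```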
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_function_via_global_checkpoint(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
@ -316,17 +334,28 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=3, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_function_with_kwargs(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_function_with_kwargs(self, device, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
def fn(x, y):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
|
||||
gn, torch.sin(x), y, use_reentrant=False
|
||||
)
|
||||
|
||||
x = torch.randn(4, 4, device=device, requires_grad=True)
|
||||
@ -336,11 +365,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=3, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_sequential_layers(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_sequential_layers(self, device, partition_fn):
|
||||
def gn(x):
|
||||
x = x.cos()
|
||||
for _ in range(3):
|
||||
@ -361,11 +401,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
freqs=[2, 18],
|
||||
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_multiple_checkpoints(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_multiple_checkpoints(self, device, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
@ -383,11 +434,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=6, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_module(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_module(self, device, partition_fn):
|
||||
class MockModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -411,11 +473,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
|
||||
)
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_decomps(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_decomps(self, device, partition_fn):
|
||||
# Ensures that tags are passed on through decompositions as well
|
||||
class MockModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -443,6 +516,7 @@ class ActivationCheckpointingViaTagsTests(
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
decompositions=lambda: import_module(
|
||||
"torch._inductor.compile_fx"
|
||||
).select_decomp_table(),
|
||||
@ -702,7 +776,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_must_recompute(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
|
||||
def context_fn_must_recompute_mm():
|
||||
must_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -723,9 +804,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
),
|
||||
)
|
||||
|
||||
def _test(context_fn, bw_compiler):
|
||||
def _test(context_fn, bw_compiler, partition_fn):
|
||||
def gn(x):
|
||||
return torch.sigmoid(torch.matmul(x, x))
|
||||
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
|
||||
|
||||
def fn(x):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
@ -739,14 +820,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
fw_compiler = functools.partial(
|
||||
count_ops,
|
||||
freq=1,
|
||||
freq=2,
|
||||
op=torch.ops.aten.mm.default,
|
||||
)
|
||||
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@ -754,17 +835,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
context_fn=context_fn_must_recompute_mm,
|
||||
bw_compiler=functools.partial(
|
||||
count_ops,
|
||||
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
|
||||
freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
|
||||
op=torch.ops.aten.mm.default,
|
||||
),
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
_test(
|
||||
context_fn=context_fn_no_recompute_mm,
|
||||
bw_compiler=functools.partial(
|
||||
count_ops,
|
||||
freq=2, # 2 bwd mm ops per fwd matmul
|
||||
freq=4, # 2 bwd mm ops per fwd matmul
|
||||
op=torch.ops.aten.mm.default,
|
||||
),
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
def test_sac_with_partial_context_fn(self):
|
||||
@ -801,7 +884,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm(
|
||||
self, device, partition_fn
|
||||
):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -841,15 +933,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
|
||||
self, device
|
||||
self, device, partition_fn
|
||||
):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
@ -889,7 +988,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
disable_functionalization=True,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
@ -897,7 +996,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_triton_kernel(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
|
||||
# Copy of the above test, but make sure that having a triton kernel in the
|
||||
# region does not error.
|
||||
def add_one(x):
|
||||
@ -957,14 +1063,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_tensor_subclass(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1007,14 +1120,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_custom_rule(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
|
||||
def _get_custom_policy(meta):
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1072,14 +1192,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn(no_recompute_list):
|
||||
return create_selective_checkpoint_contexts(
|
||||
_get_custom_policy(no_recompute_list=no_recompute_list)
|
||||
@ -1118,14 +1245,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_outplace_op(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1163,14 +1297,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_list_ops(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
# recompute everything
|
||||
no_recompute_list = []
|
||||
@ -1206,7 +1347,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
@ -1217,7 +1358,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
"requires TorchDispatchMode + torch.compile work to complete"
|
||||
)
|
||||
@requires_cuda_and_triton
|
||||
def test_compile_selective_checkpoint_inplace_op(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1257,7 +1405,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
@ -1265,7 +1413,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@torch._inductor.config.patch(fallback_random=True)
|
||||
def test_compile_selective_checkpoint_random_op(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
|
||||
for preserve_rng_state in [True, False]:
|
||||
|
||||
def selective_checkpointing_context_fn():
|
||||
@ -1312,7 +1467,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
|
||||
@ -1324,7 +1479,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_invalid_context(self):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y)) * y
|
||||
|
||||
@ -1353,7 +1515,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "must generate a tuple of two `TorchDispatchMode`s"
|
||||
@ -1362,7 +1524,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
|
||||
def test_compile_selective_checkpoint_parametrization(self):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
|
||||
def sac_policy():
|
||||
def _recomp_policy():
|
||||
def _custom_policy(ctx, func, *args, **kwargs):
|
||||
@ -1425,7 +1594,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
bw_compiler = functools.partial(
|
||||
count_ops,
|
||||
freqs=[
|
||||
2, # 1 from mul recompute, 1 from mul backward
|
||||
# 1 from mul recompute, 1 from mul backward
|
||||
# w/o CSE, we have one extra mul
|
||||
3 if partition_fn is default_partition else 2,
|
||||
1,
|
||||
],
|
||||
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
|
||||
@ -1434,7 +1605,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
model = MLPModule()
|
||||
|
||||
@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
|
||||
return grad_output * x, grad_output * x
|
||||
|
||||
def f(a, b):
|
||||
return FwBwMutation.apply(a, b)
|
||||
return FwBwMutation.apply(a, b).sin_().clone()
|
||||
|
||||
inps = [
|
||||
torch.ones(3, 3, requires_grad=True),
|
||||
@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
|
||||
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
|
||||
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
|
||||
return (mul, add)""",
|
||||
clone = torch.ops.aten.clone.default(mul)
|
||||
sin_ = torch.ops.aten.sin_.default(mul); mul = None
|
||||
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
|
||||
return (clone_1, add, clone)""",
|
||||
)
|
||||
|
||||
# important bit: there is 1 mutation in the bw
|
||||
self.assertExpectedInline(
|
||||
bw_graph[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, add, tangents_1):
|
||||
def forward(self, add, clone, tangents_1):
|
||||
cos = torch.ops.aten.cos.default(clone); clone = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
|
||||
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
|
||||
return (mul_1, None)""",
|
||||
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
|
||||
return (mul_2, None)""",
|
||||
)
|
||||
|
||||
def test_fw_bw_mutation_no_functionalization2(self):
|
||||
|
||||
@@ -911,8 +911,8 @@ class GraphModule(torch.nn.Module):
            op="call_function", target=torch.ops.aten.mm.default
        )
        self.assertEqual(len(mm_nodes), 4)
        self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
        self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
        self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
        self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
        self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
        self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
        self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)
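The updated assertions encode the new invariant introduced by this change set: after joint tracing, the mm nodes carry a `partitioner_tag` of either "is_forward" or "is_backward". A small illustration of locating and inspecting such nodes; `gm` here stands in for whatever `torch.fx.GraphModule` the test captured.

```python
# `gm` is assumed to be a torch.fx.GraphModule produced by the export/compile flow above.
mm_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default)
for node in mm_nodes:
    # Forward-traced nodes are tagged "is_forward"; backward-only nodes "is_backward".
    print(node.name, node.meta.get("partitioner_tag"))
```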
@@ -10,6 +10,7 @@ This file contains utilities related to functionalization in AOTAutograd:
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor
@@ -449,7 +450,7 @@ def was_tensor_metadata_updated(arg, new_arg):


# Returns the number of detected copy_
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]:
    allowed_mutation_ops = [
        torch.ops.aten.copy_.default,
        torch.ops.aten.set_.source_Tensor,
@@ -462,6 +463,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
    # NB: It would also be nice to verify that the mutations all happen at the
    # end, but we also do some administrative views after mutations so this
    # isn't actually true. (TODO: Could this cause problems for Inductor?)
    error = None
    for n in fx_g.nodes:
        if n.op == "placeholder":
            placeholders.add(n)
@@ -471,14 +473,18 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
            # this is mostly a hack to avoid failing XLA tests.
            # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
            if "set_buffer_donor_" not in str(n.args[0]):
                assert n.args[0] in placeholders, (
                    f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
                )
                if n.args[0] not in placeholders:
                    error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
            mutation_count += 1
        else:
            assert not n.target._schema.is_mutable, (
                f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
            )
            if n.target._schema.is_mutable:
                error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
    return error, mutation_count


def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
    error, mutation_count = _is_functional_graph(fx_g)
    assert error is None, error
    return mutation_count
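Splitting the old hard assert into `_is_functional_graph` (which returns an error string, or None, plus the mutation count) and a thin `assert_functional_graph` wrapper lets callers probe a graph without raising. A hedged sketch of how a caller can use it, assuming the helper lands under this module path as the relative import later in this diff suggests:

```python
from torch._functorch._aot_autograd.functional_utils import (
    _is_functional_graph,
    assert_functional_graph,
)


def graph_is_functional(fx_graph) -> bool:
    # Soft check: the default partitioner uses this to decide whether activation-checkpoint
    # handling is safe, falling back to the min-cut partitioner otherwise.
    error, _mutation_count = _is_functional_graph(fx_graph)
    return error is None


# Hard check (original behavior): raises AssertionError if a disallowed mutation is found,
# and otherwise returns the number of detected copy_ ops.
# n_copies = assert_functional_graph(fx_graph)
```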
@@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
    _proxy_tensor_disable_update_tensor_tracker,
    get_proxy_mode,
    maybe_disable_thunkify,
    maybe_enable_thunkify,
)
@@ -295,6 +296,10 @@ def create_joint(
        (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
            fn, primals
        )
        mode = get_proxy_mode()
        assert mode is not None, "Expected non-None proxy mode"
        for node in mode.tracer.graph.nodes:
            node.meta["partitioner_tag"] = "is_forward"

        # TODO: I think this hook can also be eliminated now
        if joint_fn_handle and joint_fn_handle.post_forward:
@@ -10,6 +10,7 @@ import operator
import os
import os.path
import re
import warnings
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, replace
@@ -51,6 +52,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import _is_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@@ -297,6 +299,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "is_backward"


def _has_tag_is_forward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "is_forward"


def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
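With `create_joint` tagging every node traced during the forward as "is_forward" (backward-only nodes end up "is_backward"), the partitioner can split on explicit metadata instead of re-deriving placement from dataflow. A toy illustration of reading the tag, mirroring the `_has_tag_is_forward` / `_has_tag_is_backward` predicates above:

```python
import torch.fx as fx


def split_by_partitioner_tag(graph: fx.Graph):
    # Nodes carry a "partitioner_tag" entry in node.meta, written while tracing the joint graph.
    fwd, bwd, untagged = [], [], []
    for node in graph.nodes:
        tag = node.meta.get("partitioner_tag", None)
        if tag == "is_forward":
            fwd.append(node)
        elif tag == "is_backward":
            bwd.append(node)
        else:
            untagged.append(node)
    return fwd, bwd, untagged
```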
@ -1021,105 +1027,132 @@ def default_partition(
|
||||
Returns:
|
||||
Returns the generated forward and backward Fx graph modules.
|
||||
"""
|
||||
if has_recomputable_ops(joint_module):
|
||||
return min_cut_rematerialization_partition(
|
||||
joint_module,
|
||||
_joint_inputs,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
static_lifetime_input_indices=static_lifetime_input_indices,
|
||||
)
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
# Respect the original placement of ops rather than rely on dataflow.
|
||||
forward_nodes = []
|
||||
last_node = None
|
||||
for node in joint_module.graph.nodes:
|
||||
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
|
||||
last_node = node
|
||||
assert last_node is not None
|
||||
for node in joint_module.graph.nodes:
|
||||
if not _is_tangent(node):
|
||||
forward_nodes.append(node)
|
||||
if node is last_node:
|
||||
break
|
||||
forward_node_names = OrderedSet(
|
||||
node.name for node in forward_only_graph.nodes if node.op != "output"
|
||||
node.name for node in forward_nodes if node.op != "output"
|
||||
)
|
||||
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
|
||||
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
||||
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
||||
if graph_has_recomputable_ops:
|
||||
if _is_functional_graph(joint_module.graph)[0] is not None:
|
||||
# Fall-back to previous behavior to avoid bc-breaking, although can
|
||||
# eventually flip the switch to make this a hard error.
|
||||
warnings.warn(
|
||||
"Trying to unsafely apply AC to a non-functional graph with the "
|
||||
"default partitioner. Falling back to min-cut partitioner."
|
||||
)
|
||||
return min_cut_rematerialization_partition(
|
||||
joint_module,
|
||||
_joint_inputs,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
static_lifetime_input_indices=static_lifetime_input_indices,
|
||||
)
|
||||
|
||||
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
|
||||
|
||||
if not config.unsafe_allow_optimization_of_collectives:
|
||||
force_save_collectives(joint_module)
|
||||
|
||||
force_save_bw_mutation_src(joint_module)
|
||||
|
||||
if static_lifetime_input_indices is None:
|
||||
static_lifetime_input_indices = []
|
||||
node_info = classify_nodes(
|
||||
joint_module, static_lifetime_input_indices, num_fwd_outputs
|
||||
)
|
||||
|
||||
saved_values = []
|
||||
saved_sym_nodes = []
|
||||
|
||||
def is_mutated_later_in_fw(node):
|
||||
if _has_tag_is_backward(node):
|
||||
return False
|
||||
tensor_arg_aliases = [
|
||||
x
|
||||
for x in node.args
|
||||
if isinstance(x, fx.Node)
|
||||
and "val" in x.meta
|
||||
and isinstance(x.meta["val"], torch.Tensor)
|
||||
]
|
||||
while len(tensor_arg_aliases) > 0:
|
||||
a = tensor_arg_aliases.pop()
|
||||
for u in a.users:
|
||||
if not isinstance(u.target, torch._ops.OpOverload):
|
||||
continue
|
||||
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
|
||||
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
|
||||
if (
|
||||
# one of the args was mutated
|
||||
u.target._schema.is_mutable
|
||||
# and the mutation happens "later"
|
||||
and order[u] > order[node]
|
||||
# and the mutation happened during the forward
|
||||
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
|
||||
):
|
||||
for idx, alias_info in enumerate(u.target._schema.arguments):
|
||||
if alias_info.is_write and u.args[idx] is a:
|
||||
return True
|
||||
elif u.target.is_view:
|
||||
tensor_arg_aliases.append(u)
|
||||
return False
|
||||
def is_tensor(node):
|
||||
# This node returns a single tensor output
|
||||
return (
|
||||
"tensor_meta" in node.meta
|
||||
and node.op == "call_function"
|
||||
and isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
|
||||
)
|
||||
|
||||
def is_multi_output(node):
|
||||
return (
|
||||
not is_tensor(node)
|
||||
and all(user.target == operator.getitem for user in node.users)
|
||||
and len(node.users) > 0
|
||||
)
|
||||
|
||||
def is_impure(node):
|
||||
# wait tensor is an "impure" op according to DCE's definition of impure
|
||||
# (see is_impure in torch/fx/node.py), but it survives past
|
||||
# functionalization and can be safely dup'd and reordered under the
|
||||
# assumption SPMD.
|
||||
return (
|
||||
node.is_impure(impure_random=False)
|
||||
and node.op
|
||||
not in (
|
||||
"placeholder",
|
||||
"output",
|
||||
)
|
||||
and node.target is not torch.ops._c10d_functional.wait_tensor.default
|
||||
)
|
||||
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.name not in forward_node_names:
|
||||
# if a node isn't "required" to be in the forward, but any of its arguments
|
||||
# are later mutated in the forward, then it must have been run in the forward
|
||||
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
|
||||
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
|
||||
if is_mutated_later_in_fw(node):
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if is_sym_node(node):
|
||||
# Symints must be kept separate from tensors so that PythonFunction only calls
|
||||
# save_for_backward on tensors and stashes symints in autograd .ctx
|
||||
saved_sym_nodes.append(node)
|
||||
elif (
|
||||
"tensor_meta" not in node.meta
|
||||
and node.op == "call_function"
|
||||
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
|
||||
):
|
||||
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
|
||||
users = node.users
|
||||
assert all(user.target is operator.getitem for user in users)
|
||||
saved_values.extend(users)
|
||||
else:
|
||||
backward_usages = [
|
||||
n for n in node.users if n.name not in forward_node_names
|
||||
]
|
||||
if "tensor_meta" in node.meta and all(
|
||||
is_sym_node(n) for n in backward_usages
|
||||
):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
else:
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if is_multi_output(node):
|
||||
# Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE.
|
||||
continue
|
||||
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if is_impure(node):
|
||||
assert not graph_has_recomputable_ops, (
|
||||
"Trying to apply AC on a graph with impure op",
|
||||
node,
|
||||
node.target,
|
||||
)
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if node.op == "call_function":
|
||||
assert is_tensor(node), f"{node}"
|
||||
backward_usages = [n for n in node.users if n.name not in forward_node_names]
|
||||
if all(is_sym_node(n) for n in backward_usages):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
continue
|
||||
if not must_recompute(node):
|
||||
saved_values.append(node)
|
||||
|
||||
saved_values = list(dict.fromkeys(saved_values).keys())
|
||||
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
|
||||
|
||||
return _extract_fwd_bwd_modules(
|
||||
if config._sync_decision_cross_ranks:
|
||||
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
|
||||
|
||||
if static_lifetime_input_nodes is None:
|
||||
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
|
||||
fw_module, bw_module = _extract_fwd_bwd_modules(
|
||||
joint_module,
|
||||
saved_values,
|
||||
saved_sym_nodes=saved_sym_nodes,
|
||||
@ -1127,6 +1160,24 @@ def default_partition(
|
||||
static_lifetime_input_nodes=static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
if graph_has_recomputable_ops:
|
||||
if graph_has_recomputable_rng_ops:
|
||||
fw_module, bw_module = functionalize_rng_ops(
|
||||
joint_module, fw_module, bw_module, len(saved_sym_nodes)
|
||||
)
|
||||
bw_module = reordering_to_mimic_autograd_engine(bw_module)
|
||||
|
||||
# raise all getitem ops to as early as possible
|
||||
# this is helpful for memory, especially in the case of aot_eager backend
|
||||
fw_module = raise_getitems(fw_module)
|
||||
bw_module = raise_getitems(bw_module)
|
||||
|
||||
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
|
||||
if len(node_info.required_bw_nodes) > 0:
|
||||
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
|
||||
|
||||
return fw_module, bw_module
|
||||
|
||||
|
||||
INT_INF = int(1e6)
|
||||
|
||||
@ -1621,7 +1672,16 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
|
||||
break
|
||||
|
||||
|
||||
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
|
||||
def is_getitem_of_multi_output(node):
|
||||
if node.target != operator.getitem:
|
||||
return False
|
||||
parent = node.args[0]
|
||||
return "tensor_meta" not in parent.meta and node.op == "call_function"
|
||||
|
||||
|
||||
def cleanup_recompute_tags(
|
||||
joint_module: fx.GraphModule, *, is_default_partition: bool
|
||||
) -> fx.GraphModule:
|
||||
"""
|
||||
If there are two consecutive checkpointed blocks with no operator in
|
||||
between, we would still want to stash the tensor at the boundary of
|
||||
@ -1658,6 +1718,20 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
|
||||
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
|
||||
# in forward graph outputs. With this, we can break the above circular dependency.
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
elif (
|
||||
"ac_graph_id" not in node.meta
|
||||
and any(must_recompute(user) for user in node.users)
|
||||
and not (
|
||||
# Avoid saving getitem nodes which are not labeled with "ac_graph_id"
|
||||
is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta
|
||||
)
|
||||
and is_default_partition
|
||||
):
|
||||
# This node is not part of the AC region and a user is marked as recompute.
|
||||
# This means it's an input to the AC region and we should save it.
|
||||
# For ease of landing, gate this to default partitioner only, but we should think
|
||||
# about flipping the switch in general as well.
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
return joint_module
|
||||
|
||||
|
||||
@ -2765,6 +2839,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
|
||||
return module
|
||||
|
||||
|
||||
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
|
||||
name_to_node = get_name_to_node(joint_module.graph)
|
||||
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.op == "placeholder" and "tangents" in node.target:
|
||||
required_bw_nodes.add(node)
|
||||
elif _must_be_in_backward(node):
|
||||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
required_bw_nodes.update(
|
||||
o for o in bwd_outputs if o is not None and o.op != "output"
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
name_to_node[node.name]
|
||||
for node in forward_only_graph.nodes
|
||||
if node.op != "output"
|
||||
)
|
||||
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
node
|
||||
for node in joint_module.graph.nodes
|
||||
if node not in required_fw_nodes and node not in required_bw_nodes
|
||||
)
|
||||
static_lifetime_input_nodes = OrderedSet(
|
||||
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
|
||||
)
|
||||
fw_cnt = 0
|
||||
fw_order = {}
|
||||
for node in joint_module.graph.nodes:
|
||||
if node in required_fw_nodes:
|
||||
fw_order[node] = fw_cnt
|
||||
fw_cnt += 1
|
||||
return NodeInfo(
|
||||
inputs,
|
||||
required_fw_nodes,
|
||||
required_bw_nodes,
|
||||
unclaimed_nodes,
|
||||
fw_order,
|
||||
static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
|
||||
def min_cut_rematerialization_partition(
|
||||
joint_module: fx.GraphModule,
|
||||
_joint_inputs,
|
||||
@ -2813,68 +2940,16 @@ def min_cut_rematerialization_partition(
|
||||
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
||||
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
||||
if graph_has_recomputable_ops:
|
||||
joint_module = cleanup_recompute_tags(joint_module)
|
||||
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
|
||||
if not config.unsafe_allow_optimization_of_collectives:
|
||||
force_save_collectives(joint_module)
|
||||
force_save_bw_mutation_src(joint_module)
|
||||
|
||||
def classify_nodes(joint_module, static_lifetime_input_indices):
|
||||
name_to_node = get_name_to_node(joint_module.graph)
|
||||
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.op == "placeholder" and "tangents" in node.target:
|
||||
required_bw_nodes.add(node)
|
||||
elif _must_be_in_backward(node):
|
||||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(
|
||||
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
|
||||
)
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
required_bw_nodes.update(
|
||||
o for o in bwd_outputs if o is not None and o.op != "output"
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
name_to_node[node.name]
|
||||
for node in forward_only_graph.nodes
|
||||
if node.op != "output"
|
||||
)
|
||||
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
node
|
||||
for node in joint_module.graph.nodes
|
||||
if node not in required_fw_nodes and node not in required_bw_nodes
|
||||
)
|
||||
static_lifetime_input_nodes = OrderedSet(
|
||||
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
|
||||
)
|
||||
fw_cnt = 0
|
||||
fw_order = {}
|
||||
for node in joint_module.graph.nodes:
|
||||
if node in required_fw_nodes:
|
||||
fw_order[node] = fw_cnt
|
||||
fw_cnt += 1
|
||||
return NodeInfo(
|
||||
inputs,
|
||||
required_fw_nodes,
|
||||
required_bw_nodes,
|
||||
unclaimed_nodes,
|
||||
fw_order,
|
||||
static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
if static_lifetime_input_indices is None:
|
||||
static_lifetime_input_indices = []
|
||||
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
|
||||
node_info = classify_nodes(
|
||||
joint_module, static_lifetime_input_indices, num_fwd_outputs
|
||||
)
|
||||
|
||||
# networkx blows up on graphs with no required backward nodes
|
||||
# Since there's nothing to partition anyway, and the default partitioner can "handle"
|
||||
|
||||
@ -459,6 +459,22 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
}
|
||||
});
|
||||
m.def("_clear_callbacks", []() { at::clearCallbacks(); });
|
||||
m.def(
|
||||
"_create_ownership_token",
|
||||
[](py::handle grad_fn) {
|
||||
return py::reinterpret_steal<py::object>(
|
||||
THPFunction_create_ownership_token(grad_fn.ptr()));
|
||||
},
|
||||
py::arg("grad_fn"),
|
||||
R"(
|
||||
Create an ownership token for a grad_fn that keeps its cdata alive.
|
||||
|
||||
Args:
|
||||
grad_fn: The grad_fn (autograd.Function) to create an ownership token for.
|
||||
|
||||
Returns:
|
||||
An ownership token object that keeps the underlying C++ Node alive.
|
||||
)");
|
||||
m.def(
|
||||
"_saved_tensors_hooks_is_enabled",
|
||||
at::SavedTensorDefaultHooks::is_enabled);
|
||||
|
||||
@ -51,6 +51,13 @@ using at::Tensor;
|
||||
PyObject* THPFunctionClass = nullptr;
|
||||
PyObject* THPGradientEdgeClass = nullptr;
|
||||
|
||||
// CDataOwner: A simple object that holds a shared_ptr to PyNode to keep it alive
|
||||
// This provides an "ownership token" for cdata
|
||||
struct THPCDataOwner {
|
||||
PyObject_HEAD
|
||||
std::shared_ptr<PyNode> cdata;
|
||||
};
|
||||
|
||||
#define THPFunction_assert(condition, ...) \
|
||||
if (!(condition)) { \
|
||||
THPUtils_setError(__VA_ARGS__); \
|
||||
@ -1785,6 +1792,96 @@ static struct PyMethodDef THPFunction_methods[] = {
|
||||
nullptr},
|
||||
{nullptr}};
|
||||
|
||||
// Forward declarations for CDataOwner
|
||||
static void THPCDataOwner_dealloc(THPCDataOwner* self);
|
||||
static PyObject* THPCDataOwner_new(
|
||||
PyTypeObject* type,
|
||||
PyObject* args,
|
||||
PyObject* kwargs);
|
||||
|
||||
// CDataOwner type implementation
|
||||
static void THPCDataOwner_dealloc(THPCDataOwner* self) {
|
||||
self->cdata.~shared_ptr<PyNode>();
|
||||
Py_TYPE(self)->tp_free((PyObject*)self);
|
||||
}
|
||||
|
||||
static PyObject* THPCDataOwner_new(
|
||||
PyTypeObject* type,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
PyObject* obj = type->tp_alloc(type, 0);
|
||||
if (!obj)
|
||||
return nullptr;
|
||||
THPCDataOwner* self = (THPCDataOwner*)obj;
|
||||
new (&self->cdata) std::shared_ptr<PyNode>();
|
||||
return obj;
|
||||
}
|
||||
|
||||
static PyTypeObject THPCDataOwnerType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"torch._C._CDataOwner", /* tp_name */
|
||||
sizeof(THPCDataOwner), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
(destructor)THPCDataOwner_dealloc, /* tp_dealloc */
|
||||
0, /* tp_vectorcall_offset */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
nullptr, /* tp_repr */
|
||||
nullptr, /* tp_as_number */
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
nullptr, /* tp_hash */
|
||||
nullptr, /* tp_call */
|
||||
nullptr, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
nullptr, /* tp_setattro */
|
||||
nullptr, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||
"Ownership token for C++ autograd Node data", /* tp_doc */
|
||||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
nullptr, /* tp_methods */
|
||||
nullptr, /* tp_members */
|
||||
nullptr, /* tp_getset */
|
||||
nullptr, /* tp_base */
|
||||
nullptr, /* tp_dict */
|
||||
nullptr, /* tp_descr_get */
|
||||
nullptr, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
nullptr, /* tp_init */
|
||||
nullptr, /* tp_alloc */
|
||||
THPCDataOwner_new /* tp_new */
|
||||
};
|
||||
|
||||
// Function to create an ownership token from a grad_fn
|
||||
PyObject* THPFunction_create_ownership_token(PyObject* grad_fn) {
|
||||
// Return None if not a THPFunction (e.g., C++ defined nodes don't need tokens)
|
||||
if (!THPFunction_Check(grad_fn)) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
auto* thp_fn = (THPFunction*)grad_fn;
|
||||
auto cdata = thp_fn->cdata.lock();
|
||||
|
||||
TORCH_CHECK(
|
||||
cdata,
|
||||
"Cannot create ownership token: the underlying PyNode has already been deallocated");
|
||||
|
||||
PyObject* obj = THPCDataOwnerType.tp_alloc(&THPCDataOwnerType, 0);
|
||||
if (!obj)
|
||||
throw python_error();
|
||||
|
||||
THPCDataOwner* owner = (THPCDataOwner*)obj;
|
||||
new (&owner->cdata) std::shared_ptr<PyNode>(std::move(cdata));
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
PyTypeObject THPFunctionType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"torch._C._FunctionBase", /* tp_name */
|
||||
@ -1833,5 +1930,12 @@ bool THPFunction_initModule(PyObject* module) {
|
||||
return false;
|
||||
Py_INCREF(&THPFunctionType);
|
||||
PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType);
|
||||
|
||||
// Initialize CDataOwner type
|
||||
if (PyType_Ready(&THPCDataOwnerType) < 0)
|
||||
return false;
|
||||
Py_INCREF(&THPCDataOwnerType);
|
||||
PyModule_AddObject(module, "_CDataOwner", (PyObject*)&THPCDataOwnerType);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -150,6 +150,9 @@ TORCH_PYTHON_API extern PyTypeObject THPFunctionType;
|
||||
TORCH_PYTHON_API extern PyObject* THPFunctionClass;
|
||||
TORCH_PYTHON_API extern PyObject* THPGradientEdgeClass;
|
||||
|
||||
// Create an ownership token that keeps the cdata (PyNode) alive
|
||||
TORCH_PYTHON_API PyObject* THPFunction_create_ownership_token(PyObject* grad_fn);
|
||||
|
||||
inline bool THPFunction_Check(PyObject* obj) {
|
||||
return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ from collections.abc import Iterator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch._C._autograd as _autograd_cpp
|
||||
from torch.autograd.graph import GradientEdge, Node
|
||||
from torch.nn import Parameter
|
||||
|
||||
@ -211,6 +212,18 @@ def stage_backward_input(
|
||||
else:
|
||||
inp.grad += dinput
|
||||
|
||||
# Create ownership tokens for intermediates BEFORE detaching
|
||||
# These tokens keep the PyNode cdata alive for use in stage_backward_weight
|
||||
for param_group in param_groups:
|
||||
param_group["intermediate_tokens"] = [
|
||||
_autograd_cpp._create_ownership_token(intermediate)
|
||||
for intermediate in param_group["intermediates"]
|
||||
]
|
||||
param_group["param_tokens"] = [
|
||||
_autograd_cpp._create_ownership_token(param)
|
||||
for param in param_group["params"]
|
||||
]
|
||||
|
||||
# stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph
|
||||
# this allows autograd to clear up the graph dedicated for this tensor and free up significant memory
|
||||
for t in stage_outputs_or_loss:
|
||||
@ -238,13 +251,15 @@ def stage_backward_weight(
|
||||
valid_edges = []
|
||||
valid_grad_outputs: list[torch.Tensor] = []
|
||||
|
||||
for grads_tuple, intermediate in zip(
|
||||
param_group["grads"], param_group["intermediates"]
|
||||
for i, (grads_tuple, intermediate) in enumerate(
|
||||
zip(param_group["grads"], param_group["intermediates"])
|
||||
):
|
||||
non_none_grads = [g for g in grads_tuple if g is not None]
|
||||
if non_none_grads:
|
||||
summed_grad = sum(non_none_grads)
|
||||
valid_edges.append(GradientEdge(intermediate, 0))
|
||||
# Use pre-created ownership token to keep the intermediate node's cdata alive
|
||||
token = param_group["intermediate_tokens"][i]
|
||||
valid_edges.append(GradientEdge(intermediate, 0, ownership_token=token))
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
valid_grad_outputs.append(summed_grad)
|
||||
|
||||
@ -254,10 +269,19 @@ def stage_backward_weight(
|
||||
# because we install the hook function onto each of the intermediate autograd nodes.
|
||||
# We need to keep intermediates alive up until backward_weight, but we can free it now.
|
||||
del param_group["intermediates"]
|
||||
# Also clean up the intermediate tokens since we've copied them to valid_edges
|
||||
del param_group["intermediate_tokens"]
|
||||
|
||||
if valid_edges: # Only call autograd.grad if we have valid gradients
|
||||
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
|
||||
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
|
||||
# Use pre-created ownership tokens for weight nodes
|
||||
weights_edges = tuple(
|
||||
GradientEdge(w, 0, ownership_token=token)
|
||||
for w, token in zip(param_group["params"], param_group["param_tokens"])
|
||||
)
|
||||
# Clean up param tokens after use
|
||||
del param_group["param_tokens"]
|
||||
|
||||
dweights = torch.autograd.grad(
|
||||
valid_edges,
|
||||
weights_edges,
|
||||
|
||||
@@ -1890,6 +1890,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
        self.unshard_ops: dict[int, list[UnshardHandle]] = defaultdict(list)
        self.unsharded_stages = set()

        # to manage autograd graph lifetime
        self.ownership_tokens = {}
        self.ac_graph_execution_context: Optional[torch.utils.checkpoint.GraphExecGroup] = None

    def register_custom_function(
        self,
        computation_type: _ComputationType,
@@ -2057,6 +2061,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        self._initialize_stages(arg_mbs[0], kwarg_mbs[0])

        # TODO: we should assert its empty, and be cleaning it up as we go
        self.ownership_tokens.clear()

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: dict[int, _PipelineStageBase] = {
@@ -2162,6 +2169,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
                        kwarg_mbs[mb_index],  # type: ignore[index]
                        save_forward_output=return_outputs,
                    )
                    key = f"{stage.stage_index}_{mb_index}"
                    assert key not in self.ownership_tokens
                    self.ownership_tokens[key] = output.view_as(output).grad_fn
                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)

                    # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
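Storing `output.view_as(output).grad_fn` keyed by stage and microbatch keeps the forward's autograd graph reachable until the matching dW step runs, even after the schedule drops or detaches the stage outputs. A minimal sketch of why holding such a grad_fn reference is enough; the tensor shapes and ops are illustrative only.

```python
import torch

x = torch.randn(4, requires_grad=True)
out = x.sin()

# A zero-copy view creates a fresh autograd node whose reference chain keeps the
# upstream graph (SinBackward -> AccumulateGrad) alive through next_functions.
token = out.view_as(out).grad_fn
del out  # the original output can be released; `token` still pins the graph

assert token is not None
print(type(token).__name__, token.next_functions[0][0])  # e.g. ViewBackward0, SinBackward0
```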
@@ -2213,12 +2223,14 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
                        )
                        _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index)))
                    loss = self._maybe_get_loss(stage, mb_index)
                    stage.backward_one_chunk(
                        mb_index,
                        loss=loss,
                        full_backward=False,
                        last_backward=False,
                    )
                    self.ac_graph_exec_group = torch.utils.checkpoint.GraphExecGroup()
                    with self.ac_graph_exec_group:
                        stage.backward_one_chunk(
                            mb_index,
                            loss=loss,
                            full_backward=False,
                            last_backward=False,
                        )
                    # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                    # see [Note: V-schedule special case]
                    if is_prev_stage_on_this_rank:
@@ -2229,10 +2241,15 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
                    self._assert_unsharded(stage)
                    self.backward_counter[stage_idx] += 1
                    last_backward = self.backward_counter[stage_idx] == self._n_microbatches
                    stage.backward_weight_one_chunk(
                        mb_index,
                        last_backward=last_backward,
                    )
                    key = f"{stage.stage_index}_{mb_index}"
                    assert key in self.ownership_tokens
                    assert self.ac_graph_exec_group, "expect dI to be executed before dW"
                    with self.ac_graph_exec_group:
                        stage.backward_weight_one_chunk(
                            mb_index,
                            last_backward=last_backward,
                        )
                    del self.ownership_tokens[key]
                elif comp_type == REDUCE_GRAD:
                    grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                    stage.perform_reduce_grad(grad_scale_factor)
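Both halves of the split backward (dI via `backward_one_chunk`, dW via `backward_weight_one_chunk`) now run inside the same `torch.utils.checkpoint.GraphExecGroup`, so activation-checkpoint recomputation state for one microbatch's graph is shared across the two partial executions. The dI/dW split itself is plain autograd; a rough, non-pipeline sketch of that split under illustrative names (the GraphExecGroup wrapping is specific to checkpointed regions and is omitted here):

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
loss = lin(x).relu().sum()

# dI: gradient w.r.t. the stage input only; keep the graph alive for the later dW step.
(dx,) = torch.autograd.grad(loss, (x,), retain_graph=True)

# dW: gradient w.r.t. the weights, executed later in the schedule against the same graph.
dw_and_db = torch.autograd.grad(loss, tuple(lin.parameters()))
```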
@@ -2272,9 +2289,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
                        time_step,
                        action,
                    )
                # TODO(whc) what is the best practice for printing a multiline log?
                # logger will split it into multiple log lines, but this makes it hard to read (too wide)
                print(
                logger.error(
                    _format_pipeline_order(
                        self.pipeline_order_with_comms,  # type: ignore[arg-type]
                        error_step_number=time_step,