Compare commits


5 Commits

Author SHA1 Message Date
aeca99aecc [do not review] Use GraphExecGroup in PP
ghstack-source-id: 5d4cbf507715d077a64428cfbb0b149c9132bbed
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/167251
2025-11-14 13:29:18 -08:00
8b6f8966ce [do not review] fix cdata for pp
ghstack-source-id: 7aca872ecae8b4960b78dd0560ba910904da4fe8
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/167006
2025-11-14 13:29:18 -08:00
76379a71ac Support AC in default partitioner when functionalization is enabled
ghstack-source-id: 20f571b5ed99ec6d1b1fe86c6c6e369cb1655e72
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/166610
2025-11-14 13:29:18 -08:00
85a07dd185 NOT FOR LAND - simon's patch 2025-11-14 13:29:16 -08:00
9650b7407b [Pipelining] Minor logging fix

It does make the logging wider, but it's better than having the lines
interspersed with unrelated lines due to mixed use of print and logging.

ghstack-source-id: 725c44488f3e5a054bdf2399e3e09a04e2d847c4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167668
2025-11-14 13:28:17 -08:00
12 changed files with 643 additions and 221 deletions
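
The headline functional change in this stack is the "Support AC in default partitioner when functionalization is enabled" commit: activation checkpointing can now flow through the default partitioner rather than only the min-cut partitioner. A hedged sketch of how that path is exercised end to end, mirroring the parametrized tests further down; the toy checkpointed function and the pass-through compilers are illustrative, not taken from the PR:

import torch
import torch.utils.checkpoint
from functorch.compile import default_partition
from torch._dynamo.backends.common import aot_autograd


def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))


def fn(x, y):
    # checkpointed region: matmul/sigmoid are recomputed in the backward
    return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y, use_reentrant=False)


backend = aot_autograd(
    fw_compiler=lambda gm, _: gm.forward,  # pass-through "compilers"
    bw_compiler=lambda gm, _: gm.forward,
    partition_fn=default_partition,  # the newly supported partitioner for AC
)

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
torch.compile(fn, backend=backend)(x, y).sum().backward()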


@@ -1,7 +1,6 @@
 # Owner(s): ["oncall: distributed"]

 import contextlib
-import unittest

 import torch
 import torch.distributed as dist
@@ -357,7 +356,6 @@ class DTensorExportTest(TestCase):

     # aot_export_joint_with_descriptors on strict-exported exported_program.module()
     # is producing a joint graph with backward region missing
-    @unittest.expectedFailure
     def test_strict_export_parallelize_module_with_dtensor_input(self):
         self._run_test(strict_export_and_aot_export_joint_with_descriptors)


@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import ( from torch._dynamo.testing import (
AotEagerAndRecordGraphs, AotEagerAndRecordGraphs,
@ -24,7 +24,7 @@ from torch._dynamo.testing import (
) )
from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor from torch.testing._internal.two_tensor import TwoTensor
@ -281,7 +281,14 @@ class ActivationCheckpointingViaTagsTests(
run(export_compiler) run(export_compiler)
def test_tags_function(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function(self, device, partition_fn):
def gn(x, y): def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) return torch.sigmoid(torch.matmul(x, y))
@ -297,11 +304,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial( bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd ) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
def test_tags_function_via_global_checkpoint(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
def gn(x, y): def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) return torch.sigmoid(torch.matmul(x, y))
@ -316,17 +334,28 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial( bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd ) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
def test_tags_function_with_kwargs(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_with_kwargs(self, device, partition_fn):
def gn(x, y): def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) return torch.sigmoid(torch.matmul(x, y))
def fn(x, y): def fn(x, y):
return torch.utils.checkpoint.checkpoint( return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False gn, torch.sin(x), y, use_reentrant=False
) )
x = torch.randn(4, 4, device=device, requires_grad=True) x = torch.randn(4, 4, device=device, requires_grad=True)
@ -336,11 +365,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial( bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd ) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
def test_tags_sequential_layers(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_sequential_layers(self, device, partition_fn):
def gn(x): def gn(x):
x = x.cos() x = x.cos()
for _ in range(3): for _ in range(3):
@ -361,11 +401,22 @@ class ActivationCheckpointingViaTagsTests(
freqs=[2, 18], freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default], ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd ) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x) self._validate(fn, backend, x)
@requires_cuda_and_triton @requires_cuda_and_triton
def test_tags_multiple_checkpoints(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_multiple_checkpoints(self, device, partition_fn):
def gn(x, y): def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) return torch.sigmoid(torch.matmul(x, y))
@ -383,11 +434,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial( bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd ) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
def test_tags_module(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_module(self, device, partition_fn):
class MockModule(torch.nn.Module): class MockModule(torch.nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -411,11 +473,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial( bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default count_ops, freq=1, op=torch.ops.aten.sigmoid.default
) )
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x) self._validate(fn, backend, x)
@requires_cuda_and_triton @requires_cuda_and_triton
def test_tags_decomps(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_decomps(self, device, partition_fn):
# Ensures that tags are passed on through decompositions as well # Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module): class MockModule(torch.nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
@ -443,6 +516,7 @@ class ActivationCheckpointingViaTagsTests(
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=lambda: import_module( decompositions=lambda: import_module(
"torch._inductor.compile_fx" "torch._inductor.compile_fx"
).select_decomp_table(), ).select_decomp_table(),
@ -702,7 +776,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_recompute(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
def context_fn_must_recompute_mm(): def context_fn_must_recompute_mm():
must_recompute_list = [ must_recompute_list = [
torch.ops.aten.mm.default, torch.ops.aten.mm.default,
@ -723,9 +804,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
), ),
) )
def _test(context_fn, bw_compiler): def _test(context_fn, bw_compiler, partition_fn):
def gn(x): def gn(x):
return torch.sigmoid(torch.matmul(x, x)) return torch.cos(torch.sin(torch.matmul(x, x) @ x))
def fn(x): def fn(x):
return torch.utils.checkpoint.checkpoint( return torch.utils.checkpoint.checkpoint(
@ -739,14 +820,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
fw_compiler = functools.partial( fw_compiler = functools.partial(
count_ops, count_ops,
freq=1, freq=2,
op=torch.ops.aten.mm.default, op=torch.ops.aten.mm.default,
) )
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x) self._validate(fn, backend, x)
@ -754,17 +835,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
context_fn=context_fn_must_recompute_mm, context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial( bw_compiler=functools.partial(
count_ops, count_ops,
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3) freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
op=torch.ops.aten.mm.default, op=torch.ops.aten.mm.default,
), ),
partition_fn=partition_fn,
) )
_test( _test(
context_fn=context_fn_no_recompute_mm, context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial( bw_compiler=functools.partial(
count_ops, count_ops,
freq=2, # 2 bwd mm ops per fwd matmul freq=4, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default, op=torch.ops.aten.mm.default,
), ),
partition_fn=partition_fn,
) )
def test_sac_with_partial_context_fn(self): def test_sac_with_partial_context_fn(self):
@ -801,7 +884,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm(
self, device, partition_fn
):
def selective_checkpointing_context_fn(): def selective_checkpointing_context_fn():
no_recompute_list = [ no_recompute_list = [
torch.ops.aten.mm.default, torch.ops.aten.mm.default,
@ -841,15 +933,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization( def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device self, device, partition_fn
): ):
def selective_checkpointing_context_fn(): def selective_checkpointing_context_fn():
no_recompute_list = [ no_recompute_list = [
@ -889,7 +988,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
disable_functionalization=True, disable_functionalization=True,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
@ -897,7 +996,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_triton_kernel(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
# Copy of the above test, but make sure that having a triton kernel in the # Copy of the above test, but make sure that having a triton kernel in the
# region does not error. # region does not error.
def add_one(x): def add_one(x):
@ -957,14 +1063,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_tensor_subclass(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
def selective_checkpointing_context_fn(): def selective_checkpointing_context_fn():
no_recompute_list = [ no_recompute_list = [
torch.ops.aten.mm.default, torch.ops.aten.mm.default,
@ -1007,14 +1120,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_custom_rule(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
def _get_custom_policy(meta): def _get_custom_policy(meta):
no_recompute_list = [ no_recompute_list = [
torch.ops.aten.mm.default, torch.ops.aten.mm.default,
@ -1072,14 +1192,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_partial_ctx_fn(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
def selective_checkpointing_context_fn(no_recompute_list): def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts( return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list) _get_custom_policy(no_recompute_list=no_recompute_list)
@ -1118,14 +1245,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_outplace_op(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn(): def selective_checkpointing_context_fn():
no_recompute_list = [ no_recompute_list = [
torch.ops.aten.mm.default, torch.ops.aten.mm.default,
@ -1163,14 +1297,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_list_ops(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
def selective_checkpointing_context_fn(): def selective_checkpointing_context_fn():
# recompute everything # recompute everything
no_recompute_list = [] no_recompute_list = []
@ -1206,7 +1347,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1217,7 +1358,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
"requires TorchDispatchMode + torch.compile work to complete" "requires TorchDispatchMode + torch.compile work to complete"
) )
@requires_cuda_and_triton @requires_cuda_and_triton
def test_compile_selective_checkpoint_inplace_op(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn(): def selective_checkpointing_context_fn():
no_recompute_list = [ no_recompute_list = [
torch.ops.aten.mm.default, torch.ops.aten.mm.default,
@ -1257,7 +1405,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
self._validate(fn, backend, x, y) self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1265,7 +1413,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True) @torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
for preserve_rng_state in [True, False]: for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn(): def selective_checkpointing_context_fn():
@ -1312,7 +1467,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager, # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@ -1324,7 +1479,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton @requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_invalid_context(self): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
def gn(x, y): def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y return torch.sigmoid(torch.matmul(x, y)) * y
@ -1353,7 +1515,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
with self.assertRaisesRegex( with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s" Exception, "must generate a tuple of two `TorchDispatchMode`s"
@ -1362,7 +1524,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton @requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_compile_selective_checkpoint_parametrization(self): @parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
def sac_policy(): def sac_policy():
def _recomp_policy(): def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs): def _custom_policy(ctx, func, *args, **kwargs):
@ -1425,7 +1594,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
bw_compiler = functools.partial( bw_compiler = functools.partial(
count_ops, count_ops,
freqs=[ freqs=[
2, # 1 from mul recompute, 1 from mul backward # 1 from mul recompute, 1 from mul backward
# w/o CSE, we have one extra mul
3 if partition_fn is default_partition else 2,
1, 1,
], ],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default], ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@ -1434,7 +1605,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd( backend = aot_autograd(
fw_compiler=fw_compiler, fw_compiler=fw_compiler,
bw_compiler=bw_compiler, bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition, partition_fn=partition_fn,
) )
model = MLPModule() model = MLPModule()
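
Every test above gains the same decorator block; a minimal standalone sketch of that parametrization pattern follows. The scaffolding here (instantiate_parametrized_tests, run_tests, and the explicit name_fn) is an assumption made to keep the example self-contained, whereas the real file routes through instantiate_device_type_tests:

from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class PartitionFnParametrizationExample(TestCase):
    @parametrize(
        "partition_fn",
        [min_cut_rematerialization_partition, default_partition],
        name_fn=lambda fn: fn.__name__,  # keeps generated test names readable
    )
    def test_partition_fn_is_callable(self, partition_fn):
        # each value in the list becomes its own generated test case
        self.assertTrue(callable(partition_fn))


instantiate_parametrized_tests(PartitionFnParametrizationExample)

if __name__ == "__main__":
    run_tests()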


@@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
                 return grad_output * x, grad_output * x

         def f(a, b):
-            return FwBwMutation.apply(a, b)
+            return FwBwMutation.apply(a, b).sin_().clone()

         inps = [
             torch.ones(3, 3, requires_grad=True),
@@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
     add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
     _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
     mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
-    return (mul, add)""",
+    clone = torch.ops.aten.clone.default(mul)
+    sin_ = torch.ops.aten.sin_.default(mul); mul = None
+    clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
+    return (clone_1, add, clone)""",
         )

         # important bit: there is 1 mutation in the bw
         self.assertExpectedInline(
             bw_graph[0].code.strip(),
             """\
-def forward(self, add, tangents_1):
+def forward(self, add, clone, tangents_1):
+    cos = torch.ops.aten.cos.default(clone); clone = None
+    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
     _foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
-    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
-    return (mul_1, None)""",
+    mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
+    return (mul_2, None)""",
         )

     def test_fw_bw_mutation_no_functionalization2(self):
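
The expected forward/backward graphs above come from a custom autograd.Function that performs an in-place _foreach_mul_ on a saved buffer in both forward and backward. The class body is not part of this hunk; the following is an assumed reconstruction from the traced graphs, for orientation only:

import torch


class FwBwMutation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        add = scale + 1
        torch._foreach_mul_([add], [3])  # in-place mutation traced into the fw graph
        ctx.save_for_backward(add)
        return add * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # x is the buffer saved in forward
        torch._foreach_mul_([x], [4])  # in-place mutation traced into the bw graph
        return grad_output * x, grad_output * x


a = torch.ones(3, 3, requires_grad=True)
b = torch.ones(3, 3)  # second grad is dropped, matching `return (mul_1, None)` above
out = FwBwMutation.apply(a, b)  # the updated test additionally chains .sin_().clone()
out.sum().backward()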


@@ -911,8 +911,8 @@ class GraphModule(torch.nn.Module):
             op="call_function", target=torch.ops.aten.mm.default
         )
         self.assertEqual(len(mm_nodes), 4)
-        self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
-        self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
+        self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
+        self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
         self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
         self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
         self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)


@@ -10,6 +10,7 @@ This file contains utilities related to functionalization in AOTAutograd:
 from __future__ import annotations

 from dataclasses import dataclass
+from typing import Optional

 import torch
 from torch import Tensor
@@ -449,7 +450,7 @@ def was_tensor_metadata_updated(arg, new_arg):

 # Returns the number of detected copy_
-def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
+def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]:
     allowed_mutation_ops = [
         torch.ops.aten.copy_.default,
         torch.ops.aten.set_.source_Tensor,
@@ -462,6 +463,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
     # NB: It would also be nice to verify that the mutations all happen at the
     # end, but we also do some administrative views after mutations so this
     # isn't actually true. (TODO: Could this cause problems for Inductor?)
+    error = None
     for n in fx_g.nodes:
         if n.op == "placeholder":
             placeholders.add(n)
@@ -471,14 +473,18 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
            # this is mostly a hack to avoid failing XLA tests.
            # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
            if "set_buffer_donor_" not in str(n.args[0]):
-                assert n.args[0] in placeholders, (
-                    f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
-                )
+                if n.args[0] not in placeholders:
+                    error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
            mutation_count += 1
        else:
-            assert not n.target._schema.is_mutable, (
-                f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
-            )
+            if n.target._schema.is_mutable:
+                error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
+    return error, mutation_count
+
+
+def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
+    error, mutation_count = _is_functional_graph(fx_g)
+    assert error is None, error
     return mutation_count
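
The refactor splits the old assert into a reusable probe plus a thin asserting wrapper. A hedged sketch of how a caller can use the probe, mirroring how the rewritten default_partition (later in this diff) decides whether it can apply AC or must fall back to the min-cut partitioner; the wrapper function name here is made up, the helper is the private one introduced above:

import torch.fx as fx
from torch._functorch._aot_autograd.functional_utils import _is_functional_graph


def can_default_partition_apply_ac(joint_graph: fx.Graph) -> bool:
    error, mutation_count = _is_functional_graph(joint_graph)
    if error is not None:
        # non-functional graph: the caller should fall back to min-cut
        return False
    # functional apart from `mutation_count` allowed copy_/set_-style mutations
    return True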


@@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
 from torch._prims_common import CUDARngStateHelper
 from torch.fx.experimental.proxy_tensor import (
     _proxy_tensor_disable_update_tensor_tracker,
+    get_proxy_mode,
     maybe_disable_thunkify,
     maybe_enable_thunkify,
 )
@@ -295,6 +296,10 @@ def create_joint(
         (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
             fn, primals
         )
+        mode = get_proxy_mode()
+        assert mode is not None, "Expected non-None proxy mode"
+        for node in mode.tracer.graph.nodes:
+            node.meta["partitioner_tag"] = "is_forward"

         # TODO: I think this hook can also be eliminated now
         if joint_fn_handle and joint_fn_handle.post_forward:
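
With every node traced before the backward call now tagged partitioner_tag="is_forward" (backward nodes already carry "is_backward"), a consumer can recover the original forward/backward placement from node metadata alone. A small illustrative sketch; the helper name is made up, and the partitioner diff below adds its own _has_tag_is_forward predicate:

import torch.fx as fx


def split_nodes_by_partitioner_tag(joint_graph: fx.Graph):
    forward_nodes, backward_nodes, untagged = [], [], []
    for node in joint_graph.nodes:
        tag = node.meta.get("partitioner_tag")
        if tag == "is_forward":
            forward_nodes.append(node)
        elif tag == "is_backward":
            backward_nodes.append(node)
        else:
            untagged.append(node)
    return forward_nodes, backward_nodes, untagged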


@ -10,6 +10,7 @@ import operator
import os import os
import os.path import os.path
import re import re
import warnings
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
@ -51,6 +52,7 @@ from ._activation_checkpointing.knapsack import (
) )
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import _is_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@ -297,6 +299,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward" return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_is_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_forward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool: def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward" return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
@ -1021,90 +1027,110 @@ def default_partition(
Returns: Returns:
Returns the generated forward and backward Fx graph modules. Returns the generated forward and backward Fx graph modules.
""" """
if has_recomputable_ops(joint_module): # Respect the original placement of ops rather than rely on dataflow.
forward_nodes = []
last_node = None
for node in joint_module.graph.nodes:
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
last_node = node
assert last_node is not None
for node in joint_module.graph.nodes:
if not _is_tangent(node):
forward_nodes.append(node)
if node is last_node:
break
forward_node_names = OrderedSet(
node.name for node in forward_nodes if node.op != "output"
)
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
if _is_functional_graph(joint_module.graph)[0] is not None:
# Fall-back to previous behavior to avoid bc-breaking, although can
# eventually flip the switch to make this a hard error.
warnings.warn(
"Trying to unsafely apply AC to a non-functional graph with the "
"default partitioner. Falling back to min-cut partitioner."
)
return min_cut_rematerialization_partition( return min_cut_rematerialization_partition(
joint_module, joint_module,
_joint_inputs, _joint_inputs,
num_fwd_outputs=num_fwd_outputs, num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices, static_lifetime_input_indices=static_lifetime_input_indices,
) )
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( if not config.unsafe_allow_optimization_of_collectives:
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
) )
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
saved_values = [] saved_values = []
saved_sym_nodes = [] saved_sym_nodes = []
def is_mutated_later_in_fw(node): def is_tensor(node):
if _has_tag_is_backward(node): # This node returns a single tensor output
return False return (
tensor_arg_aliases = [ "tensor_meta" in node.meta
x and node.op == "call_function"
for x in node.args and isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
if isinstance(x, fx.Node) )
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor) def is_multi_output(node):
] return (
while len(tensor_arg_aliases) > 0: not is_tensor(node)
a = tensor_arg_aliases.pop() and all(user.target == operator.getitem for user in node.users)
for u in a.users: and len(node.users) > 0
if not isinstance(u.target, torch._ops.OpOverload): )
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward", def is_impure(node):
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values) # wait tensor is an "impure" op according to DCE's definition of impure
if ( # (see is_impure in torch/fx/node.py), but it survives past
# one of the args was mutated # functionalization and can be safely dup'd and reordered under the
u.target._schema.is_mutable # assumption SPMD.
# and the mutation happens "later" return (
and order[u] > order[node] node.is_impure(impure_random=False)
# and the mutation happened during the forward and node.op
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u)) not in (
): "placeholder",
for idx, alias_info in enumerate(u.target._schema.arguments): "output",
if alias_info.is_write and u.args[idx] is a: )
return True and node.target is not torch.ops._c10d_functional.wait_tensor.default
elif u.target.is_view: )
tensor_arg_aliases.append(u)
return False
for node in joint_module.graph.nodes: for node in joint_module.graph.nodes:
if node.name not in forward_node_names: if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue continue
if is_sym_node(node): if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls # Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx # save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node) saved_sym_nodes.append(node)
elif ( continue
"tensor_meta" not in node.meta if is_multi_output(node):
and node.op == "call_function" # Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE.
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) continue
): if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
# Since we can't save tuple of tensor values, we need to flatten out what we're saving saved_values.append(node)
users = node.users continue
assert all(user.target is operator.getitem for user in users) if is_impure(node):
saved_values.extend(users) assert not graph_has_recomputable_ops, (
else: "Trying to apply AC on a graph with impure op",
backward_usages = [ node,
n for n in node.users if n.name not in forward_node_names node.target,
] )
if "tensor_meta" in node.meta and all( saved_values.append(node)
is_sym_node(n) for n in backward_usages continue
): if node.op == "call_function":
assert is_tensor(node), f"{node}"
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward, # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data, # and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
@ -1114,12 +1140,19 @@ def default_partition(
# then we would be obligated to clone the input before saving it to appease autograd. # then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug). # (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages) saved_sym_nodes.extend(backward_usages)
else: continue
if not must_recompute(node):
saved_values.append(node) saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys()) saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
return _extract_fwd_bwd_modules( if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
if static_lifetime_input_nodes is None:
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module, joint_module,
saved_values, saved_values,
saved_sym_nodes=saved_sym_nodes, saved_sym_nodes=saved_sym_nodes,
@ -1127,6 +1160,24 @@ def default_partition(
static_lifetime_input_nodes=static_lifetime_input_nodes, static_lifetime_input_nodes=static_lifetime_input_nodes,
) )
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)
# raise all getitem ops to as early as possible
# this is helpful for memory, especially in the case of aot_eager backend
fw_module = raise_getitems(fw_module)
bw_module = raise_getitems(bw_module)
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
if len(node_info.required_bw_nodes) > 0:
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
return fw_module, bw_module
INT_INF = int(1e6) INT_INF = int(1e6)
@ -1621,7 +1672,16 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
break break
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: def is_getitem_of_multi_output(node):
if node.target != operator.getitem:
return False
parent = node.args[0]
return "tensor_meta" not in parent.meta and node.op == "call_function"
def cleanup_recompute_tags(
joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
""" """
If there are two consecutive checkpointed blocks with no operator in If there are two consecutive checkpointed blocks with no operator in
between, we would still want to stash the tensor at the boundary of between, we would still want to stash the tensor at the boundary of
@ -1658,6 +1718,20 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out` # Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
# in forward graph outputs. With this, we can break the above circular dependency. # in forward graph outputs. With this, we can break the above circular dependency.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
elif (
"ac_graph_id" not in node.meta
and any(must_recompute(user) for user in node.users)
and not (
# Avoid saving getitem nodes which are not labeled with "ac_graph_id"
is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta
)
and is_default_partition
):
# This node is not part of the AC region and a user is marked as recompute.
# This means it's an input to the AC region and we should save it.
# For ease of landing, gate this to default partitioner only, but we should think
# about flipping the switch in general as well.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
return joint_module return joint_module
@ -2765,6 +2839,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
return module return module
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
def min_cut_rematerialization_partition( def min_cut_rematerialization_partition(
joint_module: fx.GraphModule, joint_module: fx.GraphModule,
_joint_inputs, _joint_inputs,
@ -2813,68 +2940,16 @@ def min_cut_rematerialization_partition(
graph_has_recomputable_ops = has_recomputable_ops(joint_module) graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops: if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module) joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
if not config.unsafe_allow_optimization_of_collectives: if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module) force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module) force_save_bw_mutation_src(joint_module)
def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
if static_lifetime_input_indices is None: if static_lifetime_input_indices is None:
static_lifetime_input_indices = [] static_lifetime_input_indices = []
node_info = classify_nodes(joint_module, static_lifetime_input_indices) node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
# networkx blows up on graphs with no required backward nodes # networkx blows up on graphs with no required backward nodes
# Since there's nothing to partition anyway, and the default partitioner can "handle" # Since there's nothing to partition anyway, and the default partitioner can "handle"
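
The rewritten default_partition above keys several decisions off node.meta["recompute"] being CheckpointPolicy.MUST_SAVE. A user-level sketch of where such tags come from, via selective activation checkpointing in eager mode; the toy policy and shapes are illustrative:

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def policy_fn(ctx, op, *args, **kwargs):
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE  # never recompute matmuls
    return CheckpointPolicy.PREFER_RECOMPUTE


def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))


def fn(x, y):
    return checkpoint(
        gn,
        x,
        y,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
    )


x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
fn(x, y).sum().backward()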


@@ -459,6 +459,22 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
     }
   });
   m.def("_clear_callbacks", []() { at::clearCallbacks(); });
+  m.def(
+      "_create_ownership_token",
+      [](py::handle grad_fn) {
+        return py::reinterpret_steal<py::object>(
+            THPFunction_create_ownership_token(grad_fn.ptr()));
+      },
+      py::arg("grad_fn"),
+      R"(
+Create an ownership token for a grad_fn that keeps its cdata alive.
+
+Args:
+    grad_fn: The grad_fn (autograd.Function) to create an ownership token for.
+
+Returns:
+    An ownership token object that keeps the underlying C++ Node alive.
+)");
   m.def(
       "_saved_tensors_hooks_is_enabled",
       at::SavedTensorDefaultHooks::is_enabled);


@ -51,6 +51,13 @@ using at::Tensor;
PyObject* THPFunctionClass = nullptr; PyObject* THPFunctionClass = nullptr;
PyObject* THPGradientEdgeClass = nullptr; PyObject* THPGradientEdgeClass = nullptr;
// CDataOwner: A simple object that holds a shared_ptr to PyNode to keep it alive
// This provides an "ownership token" for cdata
struct THPCDataOwner {
PyObject_HEAD
std::shared_ptr<PyNode> cdata;
};
#define THPFunction_assert(condition, ...) \ #define THPFunction_assert(condition, ...) \
if (!(condition)) { \ if (!(condition)) { \
THPUtils_setError(__VA_ARGS__); \ THPUtils_setError(__VA_ARGS__); \
@ -1785,6 +1792,96 @@ static struct PyMethodDef THPFunction_methods[] = {
nullptr}, nullptr},
{nullptr}}; {nullptr}};
// Forward declarations for CDataOwner
static void THPCDataOwner_dealloc(THPCDataOwner* self);
static PyObject* THPCDataOwner_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwargs);
// CDataOwner type implementation
static void THPCDataOwner_dealloc(THPCDataOwner* self) {
self->cdata.~shared_ptr<PyNode>();
Py_TYPE(self)->tp_free((PyObject*)self);
}
static PyObject* THPCDataOwner_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwargs) {
PyObject* obj = type->tp_alloc(type, 0);
if (!obj)
return nullptr;
THPCDataOwner* self = (THPCDataOwner*)obj;
new (&self->cdata) std::shared_ptr<PyNode>();
return obj;
}
static PyTypeObject THPCDataOwnerType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._CDataOwner", /* tp_name */
sizeof(THPCDataOwner), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)THPCDataOwner_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"Ownership token for C++ autograd Node data", /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
THPCDataOwner_new /* tp_new */
};
// Function to create an ownership token from a grad_fn
PyObject* THPFunction_create_ownership_token(PyObject* grad_fn) {
// Return None if not a THPFunction (e.g., C++ defined nodes don't need tokens)
if (!THPFunction_Check(grad_fn)) {
Py_RETURN_NONE;
}
auto* thp_fn = (THPFunction*)grad_fn;
auto cdata = thp_fn->cdata.lock();
TORCH_CHECK(
cdata,
"Cannot create ownership token: the underlying PyNode has already been deallocated");
PyObject* obj = THPCDataOwnerType.tp_alloc(&THPCDataOwnerType, 0);
if (!obj)
throw python_error();
THPCDataOwner* owner = (THPCDataOwner*)obj;
new (&owner->cdata) std::shared_ptr<PyNode>(std::move(cdata));
return obj;
}
PyTypeObject THPFunctionType = { PyTypeObject THPFunctionType = {
PyVarObject_HEAD_INIT(nullptr, 0) PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._FunctionBase", /* tp_name */ "torch._C._FunctionBase", /* tp_name */
@ -1833,5 +1930,12 @@ bool THPFunction_initModule(PyObject* module) {
return false; return false;
Py_INCREF(&THPFunctionType); Py_INCREF(&THPFunctionType);
PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType); PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType);
// Initialize CDataOwner type
if (PyType_Ready(&THPCDataOwnerType) < 0)
return false;
Py_INCREF(&THPCDataOwnerType);
PyModule_AddObject(module, "_CDataOwner", (PyObject*)&THPCDataOwnerType);
return true; return true;
} }


@@ -150,6 +150,9 @@ TORCH_PYTHON_API extern PyTypeObject THPFunctionType;
 TORCH_PYTHON_API extern PyObject* THPFunctionClass;
 TORCH_PYTHON_API extern PyObject* THPGradientEdgeClass;

+// Create an ownership token that keeps the cdata (PyNode) alive
+TORCH_PYTHON_API PyObject* THPFunction_create_ownership_token(PyObject* grad_fn);
+
 inline bool THPFunction_Check(PyObject* obj) {
   return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
 }

View File

@@ -6,6 +6,7 @@ from collections.abc import Iterator
from typing import Any, Optional, Union

import torch
import torch._C._autograd as _autograd_cpp
from torch.autograd.graph import GradientEdge, Node
from torch.nn import Parameter
@@ -211,6 +212,18 @@ def stage_backward_input(
        else:
            inp.grad += dinput
# Create ownership tokens for intermediates BEFORE detaching
# These tokens keep the PyNode cdata alive for use in stage_backward_weight
for param_group in param_groups:
param_group["intermediate_tokens"] = [
_autograd_cpp._create_ownership_token(intermediate)
for intermediate in param_group["intermediates"]
]
param_group["param_tokens"] = [
_autograd_cpp._create_ownership_token(param)
for param in param_group["params"]
]
    # stage_outputs_or_loss are not used in backwards after this point, so we can safely remove them from the autograd graph.
    # This allows autograd to clear up the graph dedicated to these tensors and free up significant memory.
    for t in stage_outputs_or_loss:
@@ -238,13 +251,15 @@ def stage_backward_weight(
        valid_edges = []
        valid_grad_outputs: list[torch.Tensor] = []

        for i, (grads_tuple, intermediate) in enumerate(
            zip(param_group["grads"], param_group["intermediates"])
        ):
            non_none_grads = [g for g in grads_tuple if g is not None]
            if non_none_grads:
                summed_grad = sum(non_none_grads)
                # Use pre-created ownership token to keep the intermediate node's cdata alive
token = param_group["intermediate_tokens"][i]
valid_edges.append(GradientEdge(intermediate, 0, ownership_token=token))
                # pyrefly: ignore [bad-argument-type]
                valid_grad_outputs.append(summed_grad)
@@ -254,10 +269,19 @@ def stage_backward_weight(
        # because we install the hook function onto each of the intermediate autograd nodes.
        # We need to keep the intermediates alive up until backward_weight, but we can free them now.
        del param_group["intermediates"]
# Also clean up the intermediate tokens since we've copied them to valid_edges
del param_group["intermediate_tokens"]
        if valid_edges:  # Only call autograd.grad if we have valid gradients
            # [NEW!] Able to pass a GradientEdge to autograd.grad as output
            # Use pre-created ownership tokens for weight nodes
weights_edges = tuple(
GradientEdge(w, 0, ownership_token=token)
for w, token in zip(param_group["params"], param_group["param_tokens"])
)
# Clean up param tokens after use
del param_group["param_tokens"]
            dweights = torch.autograd.grad(
                valid_edges,
                weights_edges,
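Taken together, the tokens created in `stage_backward_input` exist so that the `GradientEdge`s built here for the dW pass remain valid after the intermediates themselves are deleted. A condensed sketch of that pattern, with stand-in names (`intermediate_nodes`, `weight_nodes`, `grad_outputs`) and assuming the `ownership_token=` keyword this PR adds to `GradientEdge`:

```python
import torch
import torch._C._autograd as _autograd_cpp
from torch.autograd.graph import GradientEdge


def dw_pass(intermediate_nodes, weight_nodes, grad_outputs):
    # Pin the C++ node data before anything detaches or frees the graph.
    i_tokens = [_autograd_cpp._create_ownership_token(n) for n in intermediate_nodes]
    w_tokens = [_autograd_cpp._create_ownership_token(n) for n in weight_nodes]

    # The token rides along on each edge, so the edge stays usable even after
    # the original references to the intermediates have been dropped.
    out_edges = [
        GradientEdge(n, 0, ownership_token=t)
        for n, t in zip(intermediate_nodes, i_tokens)
    ]
    in_edges = [
        GradientEdge(n, 0, ownership_token=t)
        for n, t in zip(weight_nodes, w_tokens)
    ]

    # autograd.grad accepts GradientEdge objects for both outputs and inputs.
    return torch.autograd.grad(out_edges, in_edges, grad_outputs=grad_outputs)
```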

View File

@@ -1890,6 +1890,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
        self.unshard_ops: dict[int, list[UnshardHandle]] = defaultdict(list)
        self.unsharded_stages = set()
# to manage autograd graph lifetime
self.ownership_tokens = {}
        self.ac_graph_exec_group: Optional[torch.utils.checkpoint.GraphExecGroup] = None
    def register_custom_function(
        self,
        computation_type: _ComputationType,
@@ -2057,6 +2061,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        self._initialize_stages(arg_mbs[0], kwarg_mbs[0])
        # TODO: we should assert it's empty, and clean it up as we go
self.ownership_tokens.clear()
        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: dict[int, _PipelineStageBase] = {
@@ -2162,6 +2169,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
                        kwarg_mbs[mb_index],  # type: ignore[index]
                        save_forward_output=return_outputs,
                    )
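                    # Hold a grad_fn from a no-op view of the output so this microbatch's
                    # autograd graph stays alive until its backward_weight_one_chunk runs.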
key = f"{stage.stage_index}_{mb_index}"
assert key not in self.ownership_tokens
self.ownership_tokens[key] = output.view_as(output).grad_fn
                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                    # SEND/RECV ops are avoided for the special case of 2 adjacent stages on the same rank
@@ -2213,6 +2223,8 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
                        )
                    _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index)))
                    loss = self._maybe_get_loss(stage, mb_index)
self.ac_graph_exec_group = torch.utils.checkpoint.GraphExecGroup()
with self.ac_graph_exec_group:
                        stage.backward_one_chunk(
                            mb_index,
                            loss=loss,
@@ -2229,10 +2241,15 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
                        self._assert_unsharded(stage)
                    self.backward_counter[stage_idx] += 1
                    last_backward = self.backward_counter[stage_idx] == self._n_microbatches
key = f"{stage.stage_index}_{mb_index}"
assert key in self.ownership_tokens
assert self.ac_graph_exec_group, "expect dI to be executed before dW"
with self.ac_graph_exec_group:
                        stage.backward_weight_one_chunk(
                            mb_index,
                            last_backward=last_backward,
                        )
del self.ownership_tokens[key]
                elif comp_type == REDUCE_GRAD:
                    grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                    stage.perform_reduce_grad(grad_scale_factor)
@@ -2272,9 +2289,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
                time_step,
                action,
            )
            logger.error(
                _format_pipeline_order(
                    self.pipeline_order_with_comms,  # type: ignore[arg-type]
                    error_step_number=time_step,
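A closing note on the schedule-runtime hunks above: `backward_one_chunk` (dI) and `backward_weight_one_chunk` (dW) are separate autograd calls over the same microbatch graph, so both are entered under the same `torch.utils.checkpoint.GraphExecGroup`. Below is a toy sketch of that wrapping pattern only; the stage body and tensors are placeholders, the split into two `grad()` calls is deliberately simplistic, and the precise contract of `GraphExecGroup` is whatever its documentation states, not something this sketch defines.

```python
import torch
import torch.utils.checkpoint as cp

w = torch.nn.Parameter(torch.randn(16, 16))
x = torch.randn(8, 16, requires_grad=True)


def stage_body(t):
    # Stand-in for a pipeline-stage forward under activation checkpointing.
    return t @ w


out = cp.checkpoint(stage_body, x, use_reentrant=False)
loss = out.sum()

# One group shared by both halves of the split backward, as in the runtime above.
group = cp.GraphExecGroup()

with group:
    # dI: gradient w.r.t. the stage input; retain the graph for the dW half.
    (dx,) = torch.autograd.grad(loss, (x,), retain_graph=True)

with group:
    # dW: gradient w.r.t. the weight over the same checkpointed graph.
    (dw,) = torch.autograd.grad(loss, (w,))
```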