Compare commits

...

5 Commits

Author SHA1 Message Date
aeca99aecc [do not review] Use GraphExecGroup in PP
ghstack-source-id: 5d4cbf507715d077a64428cfbb0b149c9132bbed
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/167251
2025-11-14 13:29:18 -08:00
8b6f8966ce [do not review] fix cdata for pp
ghstack-source-id: 7aca872ecae8b4960b78dd0560ba910904da4fe8
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/167006
2025-11-14 13:29:18 -08:00
76379a71ac Support AC in default partitioner when functionalization is enabled
ghstack-source-id: 20f571b5ed99ec6d1b1fe86c6c6e369cb1655e72
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/166610
2025-11-14 13:29:18 -08:00
85a07dd185 NOT FOR LAND - simon's patch 2025-11-14 13:29:16 -08:00
9650b7407b [Pipelining]
Minor logging fix

It does make the logging wider, but it's better than having the lines
interspersed with unrelated output due to the mixed use of print and logging.

ghstack-source-id: 725c44488f3e5a054bdf2399e3e09a04e2d847c4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167668
2025-11-14 13:28:17 -08:00
12 changed files with 643 additions and 221 deletions

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest
import torch
import torch.distributed as dist
@ -357,7 +356,6 @@ class DTensorExportTest(TestCase):
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
# is producing a joint graph with backward region missing
@unittest.expectedFailure
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)

View File

@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
@ -24,7 +24,7 @@ from torch._dynamo.testing import (
)
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@ -281,7 +281,14 @@ class ActivationCheckpointingViaTagsTests(
run(export_compiler)
def test_tags_function(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -297,11 +304,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_via_global_checkpoint(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -316,17 +334,28 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_with_kwargs(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_with_kwargs(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
gn, torch.sin(x), y, use_reentrant=False
)
x = torch.randn(4, 4, device=device, requires_grad=True)
@ -336,11 +365,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_sequential_layers(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_sequential_layers(self, device, partition_fn):
def gn(x):
x = x.cos()
for _ in range(3):
@ -361,11 +401,22 @@ class ActivationCheckpointingViaTagsTests(
freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_multiple_checkpoints(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_multiple_checkpoints(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -383,11 +434,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_module(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_module(self, device, partition_fn):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -411,11 +473,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_decomps(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_decomps(self, device, partition_fn):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self) -> None:
@ -443,6 +516,7 @@ class ActivationCheckpointingViaTagsTests(
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
@ -702,7 +776,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_recompute(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
@ -723,9 +804,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
),
)
def _test(context_fn, bw_compiler):
def _test(context_fn, bw_compiler, partition_fn):
def gn(x):
return torch.sigmoid(torch.matmul(x, x))
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
@ -739,14 +820,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
fw_compiler = functools.partial(
count_ops,
freq=1,
freq=2,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@ -754,17 +835,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
_test(
context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=2, # 2 bwd mm ops per fwd matmul
freq=4, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
def test_sac_with_partial_context_fn(self):
@ -801,7 +884,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm(
self, device, partition_fn
):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -841,15 +933,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device
self, device, partition_fn
):
def selective_checkpointing_context_fn():
no_recompute_list = [
@ -889,7 +988,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
disable_functionalization=True,
)
self._validate(fn, backend, x, y)
@ -897,7 +996,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_triton_kernel(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
# Copy of the above test, but make sure that having a triton kernel in the
# region does not error.
def add_one(x):
@ -957,14 +1063,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_tensor_subclass(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1007,14 +1120,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_custom_rule(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1072,14 +1192,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
@ -1118,14 +1245,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_outplace_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1163,14 +1297,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_list_ops(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
def selective_checkpointing_context_fn():
# recompute everything
no_recompute_list = []
@ -1206,7 +1347,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1217,7 +1358,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
"requires TorchDispatchMode + torch.compile work to complete"
)
@requires_cuda_and_triton
def test_compile_selective_checkpoint_inplace_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1257,7 +1405,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1265,7 +1413,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
@ -1312,7 +1467,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@ -1324,7 +1479,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_invalid_context(self):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
@ -1353,7 +1515,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
@ -1362,7 +1524,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_compile_selective_checkpoint_parametrization(self):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
def sac_policy():
def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
@ -1425,7 +1594,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
bw_compiler = functools.partial(
count_ops,
freqs=[
2, # 1 from mul recompute, 1 from mul backward
# 1 from mul recompute, 1 from mul backward
# w/o CSE, we have one extra mul
3 if partition_fn is default_partition else 2,
1,
],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@ -1434,7 +1605,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
model = MLPModule()
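
As a point of reference for the parametrization above, here is a minimal, self-contained sketch of the pattern these tests exercise: compiling a checkpointed function through aot_autograd with the default partitioner. The no-op compilers and the gn/fn helpers are illustrative, and recomputation in the backward via default_partition assumes a build with this stack applied.

import torch
import torch.utils.checkpoint
from functorch.compile import default_partition, nop
from torch._dynamo.backends.common import aot_autograd

def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))

def fn(x, y):
    # Checkpointed region the partitioner is expected to recompute in the backward.
    return torch.utils.checkpoint.checkpoint(
        gn, torch.sin(x), y, use_reentrant=False
    )

backend = aot_autograd(
    fw_compiler=nop,
    bw_compiler=nop,
    partition_fn=default_partition,
)

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
torch.compile(fn, backend=backend, fullgraph=True)(x, y).sum().backward()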

View File

@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
return grad_output * x, grad_output * x
def f(a, b):
return FwBwMutation.apply(a, b)
return FwBwMutation.apply(a, b).sin_().clone()
inps = [
torch.ones(3, 3, requires_grad=True),
@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
return (mul, add)""",
clone = torch.ops.aten.clone.default(mul)
sin_ = torch.ops.aten.sin_.default(mul); mul = None
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
return (clone_1, add, clone)""",
)
# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, tangents_1):
def forward(self, add, clone, tangents_1):
cos = torch.ops.aten.cos.default(clone); clone = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
return (mul_2, None)""",
)
def test_fw_bw_mutation_no_functionalization2(self):

View File

@ -911,8 +911,8 @@ class GraphModule(torch.nn.Module):
op="call_function", target=torch.ops.aten.mm.default
)
self.assertEqual(len(mm_nodes), 4)
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)

View File

@ -10,6 +10,7 @@ This file contains utilities related to functionalization in AOTAutograd:
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor
@ -449,7 +450,7 @@ def was_tensor_metadata_updated(arg, new_arg):
# Returns the number of detected copy_
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]:
allowed_mutation_ops = [
torch.ops.aten.copy_.default,
torch.ops.aten.set_.source_Tensor,
@ -462,6 +463,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
# NB: It would also be nice to verify that the mutations all happen at the
# end, but we also do some administrative views after mutations so this
# isn't actually true. (TODO: Could this cause problems for Inductor?)
error = None
for n in fx_g.nodes:
if n.op == "placeholder":
placeholders.add(n)
@ -471,14 +473,18 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
# this is mostly a hack to avoid failing XLA tests.
# See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
if "set_buffer_donor_" not in str(n.args[0]):
assert n.args[0] in placeholders, (
f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
)
if n.args[0] not in placeholders:
error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
mutation_count += 1
else:
assert not n.target._schema.is_mutable, (
f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
)
if n.target._schema.is_mutable:
error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
return error, mutation_count
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
error, mutation_count = _is_functional_graph(fx_g)
assert error is None, error
return mutation_count
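
To make the refactor above concrete: assert_functional_graph keeps its raising behavior, while the new _is_functional_graph returns the error (if any) together with the mutation count, so callers such as the default partitioner can probe a graph without asserting. A minimal sketch, assuming the usual private module path torch._functorch._aot_autograd.functional_utils and a build with this stack applied:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch._aot_autograd.functional_utils import (
    _is_functional_graph,
    assert_functional_graph,
)

def f(x):
    return torch.sin(x) * 2

gm = make_fx(f)(torch.randn(3))
error, mutation_count = _is_functional_graph(gm.graph)  # (None, 0) for a functional graph
assert_functional_graph(gm.graph)  # same check, but raises if an error was found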

View File

@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
_proxy_tensor_disable_update_tensor_tracker,
get_proxy_mode,
maybe_disable_thunkify,
maybe_enable_thunkify,
)
@ -295,6 +296,10 @@ def create_joint(
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
fn, primals
)
mode = get_proxy_mode()
assert mode is not None, "Expected non-None proxy mode"
for node in mode.tracer.graph.nodes:
node.meta["partitioner_tag"] = "is_forward"
# TODO: I think this hook can also be eliminated now
if joint_fn_handle and joint_fn_handle.post_forward:
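
The tagging added here runs while the joint graph is still being traced, so every node produced by the forward call is marked before the backward is traced. A hedged sketch of the same mechanism under plain make_fx (assuming get_proxy_mode and the tracer's graph behave as they do in the AOTAutograd trace above):

import torch
from torch.fx.experimental.proxy_tensor import get_proxy_mode, make_fx

def f(x):
    y = torch.sin(x)
    mode = get_proxy_mode()  # the active proxy mode while make_fx is tracing
    assert mode is not None, "Expected non-None proxy mode"
    for node in mode.tracer.graph.nodes:
        node.meta["partitioner_tag"] = "is_forward"
    return y * 2

gm = make_fx(f)(torch.randn(3))
# Nodes created before the tagging loop (placeholder, sin) carry the tag;
# the later mul and output nodes do not.
print([(n.name, n.meta.get("partitioner_tag")) for n in gm.graph.nodes])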

View File

@ -10,6 +10,7 @@ import operator
import os
import os.path
import re
import warnings
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, replace
@ -51,6 +52,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import _is_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@ -297,6 +299,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_is_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_forward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
@ -1021,105 +1027,132 @@ def default_partition(
Returns:
Returns the generated forward and backward Fx graph modules.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
# Respect the original placement of ops rather than rely on dataflow.
forward_nodes = []
last_node = None
for node in joint_module.graph.nodes:
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
last_node = node
assert last_node is not None
for node in joint_module.graph.nodes:
if not _is_tangent(node):
forward_nodes.append(node)
if node is last_node:
break
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
node.name for node in forward_nodes if node.op != "output"
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
if _is_functional_graph(joint_module.graph)[0] is not None:
# Fall-back to previous behavior to avoid bc-breaking, although can
# eventually flip the switch to make this a hard error.
warnings.warn(
"Trying to unsafely apply AC to a non-functional graph with the "
"default partitioner. Falling back to min-cut partitioner."
)
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
saved_values = []
saved_sym_nodes = []
def is_mutated_later_in_fw(node):
if _has_tag_is_backward(node):
return False
tensor_arg_aliases = [
x
for x in node.args
if isinstance(x, fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor)
]
while len(tensor_arg_aliases) > 0:
a = tensor_arg_aliases.pop()
for u in a.users:
if not isinstance(u.target, torch._ops.OpOverload):
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
if (
# one of the args was mutated
u.target._schema.is_mutable
# and the mutation happens "later"
and order[u] > order[node]
# and the mutation happened during the forward
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
):
for idx, alias_info in enumerate(u.target._schema.arguments):
if alias_info.is_write and u.args[idx] is a:
return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
def is_tensor(node):
# This node returns a single tensor output
return (
"tensor_meta" in node.meta
and node.op == "call_function"
and isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
)
def is_multi_output(node):
return (
not is_tensor(node)
and all(user.target == operator.getitem for user in node.users)
and len(node.users) > 0
)
def is_impure(node):
# wait tensor is an "impure" op according to DCE's definition of impure
# (see is_impure in torch/fx/node.py), but it survives past
# functionalization and can be safely dup'd and reordered under the
# assumption SPMD.
return (
node.is_impure(impure_random=False)
and node.op
not in (
"placeholder",
"output",
)
and node.target is not torch.ops._c10d_functional.wait_tensor.default
)
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
elif (
"tensor_meta" not in node.meta
and node.op == "call_function"
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
):
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target is operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [
n for n in node.users if n.name not in forward_node_names
]
if "tensor_meta" in node.meta and all(
is_sym_node(n) for n in backward_usages
):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
else:
saved_values.append(node)
continue
if is_multi_output(node):
# Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE.
continue
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
saved_values.append(node)
continue
if is_impure(node):
assert not graph_has_recomputable_ops, (
"Trying to apply AC on a graph with impure op",
node,
node.target,
)
saved_values.append(node)
continue
if node.op == "call_function":
assert is_tensor(node), f"{node}"
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
continue
if not must_recompute(node):
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
return _extract_fwd_bwd_modules(
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
if static_lifetime_input_nodes is None:
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
@ -1127,6 +1160,24 @@ def default_partition(
static_lifetime_input_nodes=static_lifetime_input_nodes,
)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)
# raise all getitem ops to as early as possible
# this is helpful for memory, especially in the case of aot_eager backend
fw_module = raise_getitems(fw_module)
bw_module = raise_getitems(bw_module)
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
if len(node_info.required_bw_nodes) > 0:
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
return fw_module, bw_module
INT_INF = int(1e6)
@ -1621,7 +1672,16 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
break
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
def is_getitem_of_multi_output(node):
if node.target != operator.getitem:
return False
parent = node.args[0]
return "tensor_meta" not in parent.meta and node.op == "call_function"
def cleanup_recompute_tags(
joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in
between, we would still want to stash the tensor at the boundary of
@ -1658,6 +1718,20 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
# in forward graph outputs. With this, we can break the above circular dependency.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
elif (
"ac_graph_id" not in node.meta
and any(must_recompute(user) for user in node.users)
and not (
# Avoid saving getitem nodes which are not labeled with "ac_graph_id"
is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta
)
and is_default_partition
):
# This node is not part of the AC region and a user is marked as recompute.
# This means it's an input to the AC region and we should save it.
# For ease of landing, gate this to default partitioner only, but we should think
# about flipping the switch in general as well.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
return joint_module
@ -2765,6 +2839,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
return module
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
def min_cut_rematerialization_partition(
joint_module: fx.GraphModule,
_joint_inputs,
@ -2813,68 +2940,16 @@ def min_cut_rematerialization_partition(
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
# networkx blows up on graphs with no required backward nodes
# Since there's nothing to partition anyway, and the default partitioner can "handle"

View File

@ -459,6 +459,22 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
}
});
m.def("_clear_callbacks", []() { at::clearCallbacks(); });
m.def(
"_create_ownership_token",
[](py::handle grad_fn) {
return py::reinterpret_steal<py::object>(
THPFunction_create_ownership_token(grad_fn.ptr()));
},
py::arg("grad_fn"),
R"(
Create an ownership token for a grad_fn that keeps its cdata alive.
Args:
grad_fn: The grad_fn (autograd.Function) to create an ownership token for.
Returns:
An ownership token object that keeps the underlying C++ Node alive.
)");
m.def(
"_saved_tensors_hooks_is_enabled",
at::SavedTensorDefaultHooks::is_enabled);
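
A minimal sketch of how the new binding can be exercised from Python with this stack applied (torch._C._autograd._create_ownership_token is a private API, and the Double function below is purely illustrative). The returned token pins the PyNode's cdata; for C++-defined nodes the binding simply returns None:

import torch
import torch._C._autograd as _autograd_cpp

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.randn(3, requires_grad=True)
y = Double.apply(x)
token = _autograd_cpp._create_ownership_token(y.grad_fn)  # keeps the PyNode's cdata alive
assert _autograd_cpp._create_ownership_token((x + 1).grad_fn) is None  # C++ node: no token needed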

View File

@ -51,6 +51,13 @@ using at::Tensor;
PyObject* THPFunctionClass = nullptr;
PyObject* THPGradientEdgeClass = nullptr;
// CDataOwner: A simple object that holds a shared_ptr to PyNode to keep it alive
// This provides an "ownership token" for cdata
struct THPCDataOwner {
PyObject_HEAD
std::shared_ptr<PyNode> cdata;
};
#define THPFunction_assert(condition, ...) \
if (!(condition)) { \
THPUtils_setError(__VA_ARGS__); \
@ -1785,6 +1792,96 @@ static struct PyMethodDef THPFunction_methods[] = {
nullptr},
{nullptr}};
// Forward declarations for CDataOwner
static void THPCDataOwner_dealloc(THPCDataOwner* self);
static PyObject* THPCDataOwner_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwargs);
// CDataOwner type implementation
static void THPCDataOwner_dealloc(THPCDataOwner* self) {
self->cdata.~shared_ptr<PyNode>();
Py_TYPE(self)->tp_free((PyObject*)self);
}
static PyObject* THPCDataOwner_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwargs) {
PyObject* obj = type->tp_alloc(type, 0);
if (!obj)
return nullptr;
THPCDataOwner* self = (THPCDataOwner*)obj;
new (&self->cdata) std::shared_ptr<PyNode>();
return obj;
}
static PyTypeObject THPCDataOwnerType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._CDataOwner", /* tp_name */
sizeof(THPCDataOwner), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)THPCDataOwner_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"Ownership token for C++ autograd Node data", /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
THPCDataOwner_new /* tp_new */
};
// Function to create an ownership token from a grad_fn
PyObject* THPFunction_create_ownership_token(PyObject* grad_fn) {
// Return None if not a THPFunction (e.g., C++ defined nodes don't need tokens)
if (!THPFunction_Check(grad_fn)) {
Py_RETURN_NONE;
}
auto* thp_fn = (THPFunction*)grad_fn;
auto cdata = thp_fn->cdata.lock();
TORCH_CHECK(
cdata,
"Cannot create ownership token: the underlying PyNode has already been deallocated");
PyObject* obj = THPCDataOwnerType.tp_alloc(&THPCDataOwnerType, 0);
if (!obj)
throw python_error();
THPCDataOwner* owner = (THPCDataOwner*)obj;
new (&owner->cdata) std::shared_ptr<PyNode>(std::move(cdata));
return obj;
}
PyTypeObject THPFunctionType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._FunctionBase", /* tp_name */
@ -1833,5 +1930,12 @@ bool THPFunction_initModule(PyObject* module) {
return false;
Py_INCREF(&THPFunctionType);
PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType);
// Initialize CDataOwner type
if (PyType_Ready(&THPCDataOwnerType) < 0)
return false;
Py_INCREF(&THPCDataOwnerType);
PyModule_AddObject(module, "_CDataOwner", (PyObject*)&THPCDataOwnerType);
return true;
}

View File

@ -150,6 +150,9 @@ TORCH_PYTHON_API extern PyTypeObject THPFunctionType;
TORCH_PYTHON_API extern PyObject* THPFunctionClass;
TORCH_PYTHON_API extern PyObject* THPGradientEdgeClass;
// Create an ownership token that keeps the cdata (PyNode) alive
TORCH_PYTHON_API PyObject* THPFunction_create_ownership_token(PyObject* grad_fn);
inline bool THPFunction_Check(PyObject* obj) {
return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
}

View File

@ -6,6 +6,7 @@ from collections.abc import Iterator
from typing import Any, Optional, Union
import torch
import torch._C._autograd as _autograd_cpp
from torch.autograd.graph import GradientEdge, Node
from torch.nn import Parameter
@ -211,6 +212,18 @@ def stage_backward_input(
else:
inp.grad += dinput
# Create ownership tokens for intermediates BEFORE detaching
# These tokens keep the PyNode cdata alive for use in stage_backward_weight
for param_group in param_groups:
param_group["intermediate_tokens"] = [
_autograd_cpp._create_ownership_token(intermediate)
for intermediate in param_group["intermediates"]
]
param_group["param_tokens"] = [
_autograd_cpp._create_ownership_token(param)
for param in param_group["params"]
]
# stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph
# this allows autograd to clear up the graph dedicated for this tensor and free up significant memory
for t in stage_outputs_or_loss:
@ -238,13 +251,15 @@ def stage_backward_weight(
valid_edges = []
valid_grad_outputs: list[torch.Tensor] = []
for grads_tuple, intermediate in zip(
param_group["grads"], param_group["intermediates"]
for i, (grads_tuple, intermediate) in enumerate(
zip(param_group["grads"], param_group["intermediates"])
):
non_none_grads = [g for g in grads_tuple if g is not None]
if non_none_grads:
summed_grad = sum(non_none_grads)
valid_edges.append(GradientEdge(intermediate, 0))
# Use pre-created ownership token to keep the intermediate node's cdata alive
token = param_group["intermediate_tokens"][i]
valid_edges.append(GradientEdge(intermediate, 0, ownership_token=token))
# pyrefly: ignore [bad-argument-type]
valid_grad_outputs.append(summed_grad)
@ -254,10 +269,19 @@ def stage_backward_weight(
# because we install the hook function onto each of the intermediate autograd nodes.
# We need to keep intermediates alive up until backward_weight, but we can free it now.
del param_group["intermediates"]
# Also clean up the intermediate tokens since we've copied them to valid_edges
del param_group["intermediate_tokens"]
if valid_edges: # Only call autograd.grad if we have valid gradients
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
# Use pre-created ownership tokens for weight nodes
weights_edges = tuple(
GradientEdge(w, 0, ownership_token=token)
for w, token in zip(param_group["params"], param_group["param_tokens"])
)
# Clean up param tokens after use
del param_group["param_tokens"]
dweights = torch.autograd.grad(
valid_edges,
weights_edges,
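
Putting the pieces together, the intended flow is roughly the following hedged sketch. It runs only with this stack applied, since both the ownership_token argument of GradientEdge and the _create_ownership_token binding are introduced here; Scale is an illustrative autograd.Function standing in for the real intermediate nodes:

import torch
import torch._C._autograd as _autograd_cpp
from torch.autograd.graph import GradientEdge

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        (w,) = ctx.saved_tensors
        return grad_out * w, None

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
out = Scale.apply(x, w)

node = out.grad_fn
token = _autograd_cpp._create_ownership_token(node)  # created before anything is detached

# Gradients can later be routed through the node via an edge that owns its cdata.
(dx,) = torch.autograd.grad(
    [GradientEdge(node, 0, ownership_token=token)],
    [x],
    grad_outputs=[torch.ones_like(out)],
)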

View File

@ -1890,6 +1890,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
self.unshard_ops: dict[int, list[UnshardHandle]] = defaultdict(list)
self.unsharded_stages = set()
# to manage autograd graph lifetime
self.ownership_tokens = {}
self.ac_graph_exec_group: Optional[torch.utils.checkpoint.GraphExecGroup] = None
def register_custom_function(
self,
computation_type: _ComputationType,
@ -2057,6 +2061,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
self._initialize_stages(arg_mbs[0], kwarg_mbs[0])
# TODO: we should assert its empty, and be cleaning it up as we go
self.ownership_tokens.clear()
# Based on the plan in Step 1 created in __init__:
# 2. Perform communication based on the pipeline_order
stage_index_to_stage: dict[int, _PipelineStageBase] = {
@ -2162,6 +2169,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
kwarg_mbs[mb_index], # type: ignore[index]
save_forward_output=return_outputs,
)
key = f"{stage.stage_index}_{mb_index}"
assert key not in self.ownership_tokens
self.ownership_tokens[key] = output.view_as(output).grad_fn
self._maybe_compute_loss(stage, output, target_mbs, mb_index)
# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
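
The token stored at the FORWARD step above is the grad_fn of output.view_as(output): the fresh view node's edges hold a reference into the graph that produced the stage output, so keeping that object in the dict keeps the graph reachable even after the output tensor itself is detached or dropped. A standalone sketch of just that keep-alive effect (names are illustrative):

import torch

x = torch.randn(3, requires_grad=True)
out = (x * 2).sum()

keepalive = out.view_as(out).grad_fn  # ViewBackward0 whose next edge points at out's grad_fn
out = out.detach()                    # the stage output no longer references its graph

# The producing graph is still reachable through the keep-alive node's edges.
producer, _ = keepalive.next_functions[0]
print(type(producer).__name__)  # SumBackward0
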
@ -2213,12 +2223,14 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
)
_wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index)))
loss = self._maybe_get_loss(stage, mb_index)
stage.backward_one_chunk(
mb_index,
loss=loss,
full_backward=False,
last_backward=False,
)
self.ac_graph_exec_group = torch.utils.checkpoint.GraphExecGroup()
with self.ac_graph_exec_group:
stage.backward_one_chunk(
mb_index,
loss=loss,
full_backward=False,
last_backward=False,
)
# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
# see [Note: V-schedule special case]
if is_prev_stage_on_this_rank:
@ -2229,10 +2241,15 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
self._assert_unsharded(stage)
self.backward_counter[stage_idx] += 1
last_backward = self.backward_counter[stage_idx] == self._n_microbatches
stage.backward_weight_one_chunk(
mb_index,
last_backward=last_backward,
)
key = f"{stage.stage_index}_{mb_index}"
assert key in self.ownership_tokens
assert self.ac_graph_exec_group, "expect dI to be executed before dW"
with self.ac_graph_exec_group:
stage.backward_weight_one_chunk(
mb_index,
last_backward=last_backward,
)
del self.ownership_tokens[key]
elif comp_type == REDUCE_GRAD:
grad_scale_factor = self._n_microbatches if self.scale_grads else 1
stage.perform_reduce_grad(grad_scale_factor)
@ -2272,9 +2289,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
time_step,
action,
)
# TODO(whc) what is the best practice for printing a multiline log?
# logger will split it into multiple log lines, but this makes it hard to read (too wide)
print(
logger.error(
_format_pipeline_order(
self.pipeline_order_with_comms, # type: ignore[arg-type]
error_step_number=time_step,