Compare commits


2 Commits

9270edd17c Update on "add recompute tags (from AC) into GraphModule.print_readable() by default"
This PR tweaks `GraphModule.print_readable()` to print `node.meta['recompute']` and `node.meta['ac_graph_id']` by default. The main benefit is that anyone using torch.compile with activation checkpointing will now see recompute metadata in tlparse by default when they inspect the generated joint graph.

The other options I thought about were:

(1) print all of the metadata on each node in the joint graph by default under compile. There is a lot of metadata on joint graph nodes, and I was worried this would make the printed graph too noisy.

(2) add a new kwarg to `print_readable()` to print this AC-related metadata. We could do this instead, but the function already has a number of kwargs, and we probably don't want to keep adding specialized kwargs for various bits of metadata.

(3) the current option (print by default). I thought this was the cleanest, mainly because this metadata should only show up on nodes when you are compiling AC anyway, so it is unlikely to affect users of vanilla FX outside of compile.

Example new tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/59f4e099-b827-4bc5-b828-ebb91ed99ee6/custom/-_0_0_0/aot_joint_graph_1.txt?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
```
class inner_f(torch.nn.Module):
    def forward(
        self,
        primals,
        tangents,
    ):
        primals_1: "f32[4, 4][4, 1]cuda:0"  # PlainAOTInput(idx=0)
        primals_2: "f32[4, 4][4, 1]cuda:0"  # PlainAOTInput(idx=1)
        tangents_1: "f32[4, 4][4, 1]cuda:0"  # TangentAOTInput(output=PlainAOTOutput(idx=0))
        primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # File: /data/users/hirsheybar/new2/pytorch/tmp5.py:9 in fn, code: gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
        
        # File: /data/users/hirsheybar/new2/pytorch/tmp5.py:4 in gn, code: return torch.sigmoid(torch.matmul(x, y))
        # recompute: 'CheckpointPolicy.PREFER_RECOMPUTE', ac_graph_id: '2'
        mm: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.mm.default(sin, primals_2)
        # recompute: 'CheckpointPolicy.PREFER_RECOMPUTE', ac_graph_id: '2'
        sigmoid: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sigmoid.default(mm);  mm = None
        # recompute: 'CheckpointPolicy.PREFER_RECOMPUTE', ac_graph_id: '2'
        detach: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.detach.default(sigmoid)
        detach_1: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.detach.default(detach);  detach = None
        sigmoid_backward: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sigmoid_backward.default(tangents_1, detach_1);  tangents_1 = detach_1 = None
        t: "f32[4, 4][1, 4]cuda:0" = torch.ops.aten.t.default(sin);  sin = None
        mm_1: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.mm.default(t, sigmoid_backward);  t = None
        t_1: "f32[4, 4][1, 4]cuda:0" = torch.ops.aten.t.default(primals_2);  primals_2 = None
        mm_2: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.mm.default(sigmoid_backward, t_1);  sigmoid_backward = t_1 = None
        
        # File: /data/users/hirsheybar/new2/pytorch/tmp5.py:9 in fn, code: gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
        cos: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.cos.default(primals_1);  primals_1 = None
        mul: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.mul.Tensor(mm_2, cos);  mm_2 = cos = None
        return pytree.tree_unflatten([
            sigmoid,  # PlainAOTOutput(idx=0)
            mul,  # GradAOTOutput(grad_of=PlainAOTInput(idx=0))
            mm_1,  # GradAOTOutput(grad_of=PlainAOTInput(idx=1))
        ], self._out_spec)
        
```
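If you want to reproduce output like the above locally without going through tlparse, here is a minimal sketch adapted from the test added in this PR (not part of the PR itself; it assumes a CUDA device, matching the example output):
```
import torch
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition, nop
from torch._dynamo.backends.common import aot_autograd


def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))


def fn(x, y):
    return torch.utils.checkpoint.checkpoint(
        gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
    )


def partition_fn(joint_gm, *args, **kwargs):
    # The joint graph is where the AC tags live, so print it here,
    # before it gets partitioned into forward/backward graphs.
    joint_gm.print_readable()
    return min_cut_rematerialization_partition(joint_gm, *args, **kwargs)


backend = aot_autograd(fw_compiler=nop, bw_compiler=nop, partition_fn=partition_fn)

x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
torch.compile(fn, backend=backend)(x, y)
```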





cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-13 09:48:12 -08:00
af26143dee add recompute tags (from AC) into GraphModule.print_readable() by default
[ghstack-poisoned]
2025-11-13 08:43:59 -08:00
2 changed files with 39 additions and 1 deletion


```
@@ -15,7 +15,7 @@ import torch._functorch.config
 import torch.distributed as dist
 import torch.nn as nn
 import torch.utils.checkpoint
-from functorch.compile import min_cut_rematerialization_partition
+from functorch.compile import min_cut_rematerialization_partition, nop
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.testing import (
     AotEagerAndRecordGraphs,
@@ -339,6 +339,29 @@ class ActivationCheckpointingViaTagsTests(
         backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
         self._validate(fn, backend, x, y)
 
+    @requires_cuda_and_triton
+    def test_checkpoint_shows_tags_in_tlparse(self, device):
+        def gn(x, y):
+            return torch.sigmoid(torch.matmul(x, y))
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(
+                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
+            )
+
+        x = torch.randn(4, 4, device=device, requires_grad=True)
+        y = torch.randn(4, 4, device=device, requires_grad=True)
+
+        def partition_fn(joint_gm, *args, **kwargs):
+            gm_str = joint_gm.print_readable(print_output=False)
+            self.assertTrue("# ac_graph_id: 2 - PREFER_RECOMPUTE" in gm_str)
+            return min_cut_rematerialization_partition(joint_gm, *args, **kwargs)
+
+        backend = aot_autograd(
+            fw_compiler=nop, bw_compiler=nop, partition_fn=partition_fn
+        )
+        _ = torch.compile(fn, backend=backend)(x, y)
+
     @requires_cuda_and_triton
     def test_tags_sequential_layers(self, device):
         def gn(x):
```


```
@@ -729,6 +729,21 @@ class CodeGen:
                         f"{k}: {pprint.pformat(str(v), width=80, compact=True)}\n"
                     )
                 body.append('"""\n')
+            elif hasattr(node, "meta") and node.meta:
+                # recompute tags are generated by torch.compile and put in the joint graph.
+                # These tags are load bearing enough that we want them to show up by default
+                # in tlparse, when you run torch.compile.
+                recompute = node.meta.get("recompute", None)
+                ac_graph_id = node.meta.get("ac_graph_id", None)
+                if recompute is not None and ac_graph_id is not None:
+                    body.append(
+                        f"# ac_graph_id: {str(ac_graph_id)} - {str(recompute.name)}\n"
+                    )
+                elif recompute is not None:
+                    body.append(f"# recompute: {str(recompute.name)}\n")
+                elif ac_graph_id is not None:
+                    body.append(f"# ac_graph_id: {str(ac_graph_id)}\n")
             if node.op == "placeholder":
                 assert isinstance(node.target, str)
```
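To see what this new `CodeGen` branch emits on its own, here is a hedged sketch on a plain FX module (not part of the PR): it manually attaches the two meta keys that the AC machinery would normally set during compile, so the tagging below is purely illustrative, and it assumes `CheckpointPolicy` is importable from `torch.utils.checkpoint`:
```
import torch
from torch.fx import symbolic_trace
from torch.utils.checkpoint import CheckpointPolicy


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.sigmoid(torch.matmul(x, y))


gm = symbolic_trace(M())

# Tag the matmul node the way compile's AC pass would (illustrative only).
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.matmul:
        node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        node.meta["ac_graph_id"] = 2

# With this change, the matmul line should now be preceded by a comment like
#   # ac_graph_id: 2 - PREFER_RECOMPUTE
gm.print_readable()
```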