Compare commits


1 Commit

Author SHA1 Message Date
f60c409a0a Add tests for aot_export_joint_with_descriptors annotation 2025-09-25 13:17:17 -07:00


@@ -9,6 +9,7 @@
from contextlib import ExitStack
import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
from torch._decomp import decomposition_table
@@ -761,6 +762,52 @@ class inner_f(torch.nn.Module):
        compiled_fn(*dict(model.named_parameters()).values(), inputs).sum().backward()
        self.assertIsNotNone(model.linear.weight.grad)

    def test_preserve_annotate_simple(self):
        """Test basic linear module with aot_export_joint_with_descriptors"""

        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                # Only the linear call runs under the annotation; the subtraction
                # below is intentionally left unannotated.
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                return y - 1

        inputs = (torch.randn(4, 3),)

        # Exercise both the torch.export path and the plain eager-module path.
        for with_export in [True, False]:
            with ExitStack() as stack:
                model = None
                with fx_traceback.preserve_node_meta():
                    if with_export:
                        ep = torch.export.export(SimpleLinear(), inputs)
                        model = ep.module()
                    else:
                        model = SimpleLinear()

                    joint_with_descriptors = aot_export_joint_with_descriptors(
                        stack, model, inputs, decompositions=decomposition_table
                    )

                for node in joint_with_descriptors.graph_module.graph.nodes:
                    if (
                        node.target
                        in (
                            torch.ops.prims.transpose.default,
                            torch.ops.aten.mm.default,
                            torch.ops.prims.mul.default,
                            torch.ops.prims.broadcast_in_dim.default,
                            torch.ops.prims.add.default,
                        )
                        # TODO: add annotation to backward graph nodes
                        and node.meta.get("partitioner_tag") != "is_backward"
                    ):
                        # Forward-graph compute nodes traced under annotate()
                        # should carry the custom metadata.
                        self.assertEqual(node.meta["custom"], {"pp_stage": 0})
                    if node.target == torch.ops.aten.sub.default:
                        # The subtraction happened outside the annotate() context,
                        # so it should have no custom metadata.
                        self.assertEqual(node.meta.get("custom", {}), {})


if __name__ == "__main__":
    run_tests()
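
For readers who want to try the annotation flow outside the test harness, here is a minimal standalone sketch (not part of this commit). It reuses only the APIs exercised above (fx_traceback.annotate, fx_traceback.preserve_node_meta, and torch.export) and simply prints each node's "custom" metadata, rather than asserting on the joint graph from aot_export_joint_with_descriptors as the test does.

# Standalone sketch (not part of the commit): trace SimpleLinear with torch.export
# while preserve_node_meta() is active, then print the "custom" metadata that
# fx_traceback.annotate attaches to nodes traced inside its context.
import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn


class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x):
        # Nodes produced here are expected to carry {"pp_stage": 0} in
        # node.meta["custom"]; the subtraction below stays unannotated.
        with fx_traceback.annotate({"pp_stage": 0}):
            y = self.linear(x)
        return y - 1


inputs = (torch.randn(4, 3),)
with fx_traceback.preserve_node_meta():
    ep = torch.export.export(SimpleLinear(), inputs)

for node in ep.module().graph.nodes:
    print(node.op, node.target, node.meta.get("custom", {}))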