Compare commits

...

1 Commits

Author SHA1 Message Date
89fb2567e7 Add annotation to assertion nodes 2025-11-05 17:53:13 -08:00
2 changed files with 34 additions and 0 deletions

View File

@ -721,6 +721,34 @@ class TestExport(TestCase):
)
self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id)
def test_annotate_on_assert(self):
# nodes added in `apply_runtime_assertion_pass` will be annotated
class M(torch.nn.Module):
def forward(self, x, y):
with torch.fx.traceback.annotate({"moo": 0}):
x = torch.cat([x, x])
b = y.item()
torch._check(b >= x.shape[0])
return x * b
with torch.fx.traceback.preserve_node_meta():
ep = torch.export.export(
M(),
(torch.randn(3), torch.tensor(6)),
dynamic_shapes={"x": {0: Dim("b")}, "y": None},
)
custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
)
@requires_gpu
def test_flex_attention_export(self):
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

View File

@ -165,6 +165,7 @@ def insert_deferred_runtime_asserts(
node: torch.fx.Node,
stack_trace: Optional[str] = None,
nn_module_stack: Optional[dict[str, Any]] = None,
custom: Optional[dict[str, Any]] = None,
) -> None:
fake_args = pytree.tree_map(
lambda arg: (
@ -188,6 +189,8 @@ def insert_deferred_runtime_asserts(
node.meta["stack_trace"] = stack_trace
if nn_module_stack is not None:
node.meta["nn_module_stack"] = nn_module_stack
if custom is not None:
node.meta["custom"] = custom
# Track asserts/checks we've added
added_asserts: set[sympy.Expr] = set()
@ -617,6 +620,9 @@ def insert_deferred_runtime_asserts(
_node_metadata_hook,
stack_trace=node.meta.get("stack_trace"),
nn_module_stack=node.meta.get("nn_module_stack"),
# nodes added in `apply_runtime_assertion_pass` will have the same annotation
# as the input node to the assertion
custom=node.meta.get("custom"),
),
):
if (min_val := convert(vr.lower)) is not None: