Compare commits

...

1 Commits

Author SHA1 Message Date
e1ddcc4545 [annotation] add logging for debugging annotation 2025-10-17 16:34:37 -07:00
4 changed files with 20 additions and 0 deletions

View File

@ -982,6 +982,7 @@ exclusions = {
"graph_region_expansion",
"hierarchical_compile",
"compute_dependencies",
"annotation",
}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:

View File

@ -4,6 +4,7 @@ Contains various utils for AOTAutograd, including those for handling collections
"""
import dataclasses
import logging
import operator
import warnings
from collections.abc import Callable
@ -40,6 +41,7 @@ KNOWN_TYPES = [
original_zip = zip
aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")
annotation_log = getArtifactLogger(__name__, "annotation")
def strict_zip(*iterables, strict=True, **kwargs):
@ -443,6 +445,10 @@ def _copy_metadata_to_bw_nodes_in_subgraph(
) -> None:
"""Copy metadata from forward nodes to backward nodes in a single subgraph."""
for node in fx_g.graph.nodes:
annotation_log.debug("node: %s", node.name)
seq_nr = node.meta.get("seq_nr")
annotation_log.debug("seq_nr: %s", seq_nr)
if not _is_backward_node_with_seq_nr(node):
continue
@ -478,6 +484,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
if isinstance(submod, torch.fx.GraphModule):
_collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node)
if annotation_log.isEnabledFor(logging.DEBUG):
for k, v in fwd_seq_nr_to_node.items():
annotation_log.debug("forward:: key: %s, value: %s", k, v)
# Second pass: copy metadata to backward nodes in all subgraphs
# using the global forward mapping
for submod in fx_g.modules():

View File

@ -246,4 +246,9 @@ register_artifact(
"Logs debug info for hierarchical compilation",
off_by_default=True,
)
register_artifact(
"annotation",
"Logs detailed steps of the creating annotation on graph nodes",
off_by_default=True,
)
register_artifact("custom_format_test_artifact", "Testing only", log_format="")

View File

@ -17,6 +17,7 @@ from typing import Any, Optional
import torch
import torch.fx.traceback as fx_traceback
from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
from torch._logging import getArtifactLogger
from torch.utils._traceback import CapturedTraceback
from ._compatibility import compatibility
@ -40,6 +41,7 @@ __all__ = [
log = logging.getLogger(__name__)
annotation_log = getArtifactLogger(__name__, "annotation")
@compatibility(is_backward_compatible=False)
@ -202,7 +204,9 @@ class TracerBase:
# BWD pass we retrieve the sequence_nr stored on the current
# executing autograd Node. See NOTE [ Sequence Number ].
if current_meta.get("in_grad_fn", 0) > 0:
annotation_log.debug("seq_nr from current_meta")
new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name)
node.meta["seq_nr"] = new_seq_nr
elif self.module_stack: