From 20295c017ee66977b710b74eb7b0ca0136c746f1 Mon Sep 17 00:00:00 2001
From: Riley Dulin
Date: Tue, 25 Feb 2025 03:36:06 +0000
Subject: [PATCH] Fix import of getArtifactLogger for ir_pre_fusion and
 ir_post_fusion (#147560)

Fixes #147002

The previous PR, https://github.com/pytorch/pytorch/pull/147248, had an
issue that didn't show up in CI: torch/_inductor/debug.py called
torch._logging.getArtifactLogger at module level before torch._logging was
fully initialized. This only surfaced when the file was imported directly,
without any other torch imports running first. The fix is to import
getArtifactLogger directly (from torch._logging import getArtifactLogger)
rather than accessing it as an attribute of torch._logging. The IR dumps
are also gated on isEnabledFor, so _write_ir() is only evaluated when the
artifact is actually enabled.

Also set both artifacts to off_by_default, by request, to reduce log spew.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147560
Approved by: https://github.com/Skylion007
---
 torch/_inductor/debug.py         | 11 +++++++----
 torch/_logging/_registrations.py |  2 ++
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py
index cac2a5b16744..b6a4f7bdca92 100644
--- a/torch/_inductor/debug.py
+++ b/torch/_inductor/debug.py
@@ -23,6 +23,7 @@ from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_co
 from torch import fx as fx
 from torch._dynamo.repro.after_aot import save_graph_repro
 from torch._dynamo.utils import get_debug_dir
+from torch._logging import getArtifactLogger
 from torch.fx.graph_module import GraphModule
 from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch.fx.passes.tools_common import legalize_graph
@@ -43,8 +44,8 @@ from .virtualized import V
 
 log = logging.getLogger(__name__)
 
-ir_pre_fusion_log = torch._logging.getArtifactLogger(__name__, "ir_pre_fusion")
-ir_post_fusion_log = torch._logging.getArtifactLogger(__name__, "ir_post_fusion")
+ir_pre_fusion_log = getArtifactLogger(__name__, "ir_pre_fusion")
+ir_post_fusion_log = getArtifactLogger(__name__, "ir_post_fusion")
 SchedulerNodeList = list[Any]
 BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
 GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
@@ -525,10 +526,12 @@ class DebugFormatter:
         fd.write(gm.print_readable(print_output=False))
 
     def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
-        ir_pre_fusion_log.debug("BEFORE FUSION\n%s", self._write_ir(nodes))
+        if ir_pre_fusion_log.isEnabledFor(logging.INFO):
+            ir_pre_fusion_log.info("BEFORE FUSION\n%s", self._write_ir(nodes))
 
     def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
-        ir_post_fusion_log.debug("AFTER FUSION\n%s", self._write_ir(nodes))
+        if ir_post_fusion_log.isEnabledFor(logging.INFO):
+            ir_post_fusion_log.info("AFTER FUSION\n%s", self._write_ir(nodes))
 
     def _write_ir(self, nodes: SchedulerNodeList) -> str:
         buf = io.StringIO()
diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py
index aa8bd2032b99..3d805b639ab9 100644
--- a/torch/_logging/_registrations.py
+++ b/torch/_logging/_registrations.py
@@ -109,10 +109,12 @@ register_artifact(
 register_artifact(
     "ir_pre_fusion",
     "Prints the IR before inductor fusion passes.",
+    off_by_default=True,
 )
 register_artifact(
     "ir_post_fusion",
     "Prints the IR after inductor fusion passes.",
+    off_by_default=True,
 )
 register_artifact(
     "compiled_autograd",
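
For reference, a minimal reproduction sketch of the original failure,
assuming the failure mode described in the message above (the exact command
and exception are an assumption; per the description, the bug only triggered
in a fresh interpreter with no prior torch imports):

    # repro.py -- run in a fresh interpreter with no prior torch imports.
    # Before this fix, the module-level torch._logging.getArtifactLogger(...)
    # call in torch/_inductor/debug.py could fail because torch._logging was
    # not yet fully initialized (assumed failure mode).
    import torch._inductor.debug  # noqa: F401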
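
Since both artifacts are now off_by_default, the IR dumps must be enabled
explicitly. A usage sketch, assuming the standard TORCH_LOGS artifact
mechanism applies to these names and that set_logs accepts them as keyword
arguments:

    # Option 1: enable via the environment before launching:
    #   TORCH_LOGS="ir_pre_fusion,ir_post_fusion" python my_script.py
    # Option 2 (assumption): enable programmatically via set_logs.
    import torch

    torch._logging.set_logs(ir_pre_fusion=True, ir_post_fusion=True)

    @torch.compile
    def f(x):
        return (x + 1).relu()

    f(torch.randn(8))  # inductor logs the IR before and after fusion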