From 2f1dbfea020e8db7fbeaf0767962055d8ce0779b Mon Sep 17 00:00:00 2001
From: Basil Wong
Date: Wed, 13 Nov 2024 23:09:16 +0000
Subject: [PATCH] Logging Refactor - Remove Print Statements (#139782)

Summary:
Removes print statements and implements logging via the logging library. Hopefully this will allow more control over the logging level when running models.

Test Plan:
```
AOT_PARTITIONER_DEBUG=1 buck2 run @mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2
```

Resulting output paste: P1674535630

* Full logs paste: P1674535621
```
pastry P1674535621 | grep "functorch/partitioners.py" | pastry
```
Logging results: P1674549514

Differential Revision: D61678215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139782
Approved by: https://github.com/paryxyt, https://github.com/jansel
---
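Note (not part of the commit): with the prints routed through the module logger, verbosity can be tuned with the standard Python logging API. A minimal sketch of one way to surface the new messages while reproducing the test plan; the AOT_PARTITIONER_DEBUG=1 gate still decides whether the debug-only blocks execute at all, and the exact interaction with torch's own log handlers may vary.

```python
import logging

# Show INFO-level records globally and DEBUG-level records from the
# partitioner module (remat counts, theoretical activation sizes, ...).
logging.basicConfig(level=logging.INFO)
logging.getLogger("torch._functorch.partitioners").setLevel(logging.DEBUG)
```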
 torch/_functorch/partitioners.py | 48 ++++++++++++++++----------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index e36a02853c23..3720900763cc 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -37,8 +37,8 @@
 if TYPE_CHECKING:
     import sympy
 
-AOT_PARTITIONER_DEBUG = config.debug_partitioner
-log = logging.getLogger(__name__)
+AOT_PARTITIONER_DEBUG: bool = config.debug_partitioner
+log: logging.Logger = logging.getLogger(__name__)
 
 aten = torch.ops.aten
 prims = torch.ops.prims
@@ -510,7 +510,7 @@ def _count_ops(graph: fx.Graph):
     for node in graph.nodes:
         if node.op == "call_function":
             cnt[node.target.__name__] += 1
-    print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))
+    log.info("%s", sorted(cnt.items(), key=lambda x: x[1], reverse=True))
 
 
 @functools.lru_cache(None)
@@ -824,8 +824,7 @@ def solve_min_cut(
             if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
         }
         ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
-        print("Ops banned from re-materialization: ", ops_ignored)
-        print()
+        log.info("Ops banned from re-materialization: %s", ops_ignored)
 
     def can_fuse_into_auto_functionalized(a, b):
         if b.target != torch.ops.higher_order.auto_functionalized:
@@ -921,7 +920,7 @@ def solve_min_cut(
         if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(
             node
         ):
-            log.info("materialized backwards: %s %s", node, tuple(node.users))
+            log.debug("materialized backwards: %s %s", node, tuple(node.users))
             return True
 
         # Arbitrary hack that sometimes seems to help things. The above
@@ -1171,8 +1170,8 @@ def solve_min_cut(
     try:
         cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
     except Exception:
-        print("Failed to compute min-cut on following graph:")
-        print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
+        log.info("Failed to compute min-cut on following graph:")
+        log.info("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
         visualize_min_cut_graph(nx_graph)
         raise
 
@@ -1209,7 +1208,7 @@ def visualize_min_cut_graph(nx_graph):
         # Color edges with weight 'inf' as red
         if weight == float("inf"):
             edge.set_color("red")
-    print("Visualizing the failed graph to min_cut_failed.svg")
+    log.info("Visualizing the failed graph to min_cut_failed.svg")
     dot_graph.write_svg("min_cut_failed.svg")
 
 
@@ -1897,7 +1896,6 @@ def min_cut_rematerialization_partition(
         if isinstance(node.meta.get("memory_budget", None), float):
             memory_budget = node.meta["memory_budget"]
             break
-    # print("Memory Budget: ", memory_budget)
     saved_values = choose_saved_values_set(
         joint_graph,
         node_info,
@@ -1923,14 +1921,15 @@ def min_cut_rematerialization_partition(
         bw_module = reordering_to_mimic_autograd_engine(bw_module)
 
     if AOT_PARTITIONER_DEBUG:
-        from torch._inductor.fx_utils import get_node_storage
-
-        storages = {get_node_storage(node) for node in saved_values}
-        print(
-            "Theoretical Activations Stored: ",
-            sum(_size_of(i) for i in saved_values) / 1e9,
-        )
+        # Calculate sorted sizes of saved values
         sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])
+
+        # Log total theoretical activations stored
+        total_activations_size_gb = sum(_size_of(i) for i in saved_values) / 1e9
+        log.debug("Theoretical Activations Stored: %.2f GB", total_activations_size_gb)
+
+        # Log theoretical per activation storage sizes
+        log.debug("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
         fw_module_nodes = {
             node.name for node in fw_module.graph.nodes if node.op == "call_function"
         }
@@ -1943,13 +1942,14 @@ def min_cut_rematerialization_partition(
         for node in fw_module.graph.nodes:
             if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
                 counts[str(node.target._overloadpacket)] += 1
-        print(
-            f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}"
-        )
-        print(
-            "Count of Ops Rematerialized: ",
-            sorted(counts.items(), key=lambda x: x[1], reverse=True),
+        log.debug(
+            "# remat/fw/bw: %d/%d/%d",
+            len(remat_nodes),
+            len(fw_module_nodes),
+            len(bw_module_nodes),
         )
+        rematerialized_ops = sorted(counts.items(), key=lambda x: x[1], reverse=True)
+        log.debug("Count of Ops Rematerialized: %s", rematerialized_ops)
     return fw_module, bw_module
 
 
@@ -1970,7 +1970,7 @@ def draw_graph(
     base, ext = os.path.splitext(fname)
     if not ext:
         ext = "." + config.torch_compile_graph_format
-    print(f"Writing FX graph to file: {base}{ext}")
+    log.info("Writing FX graph to file: %s%s", base, ext)
     g = graph_drawer.FxGraphDrawer(
         traced,
         figname,
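A closing note on the pattern used above (also not part of the commit): the converted calls pass %-style placeholders plus arguments rather than pre-formatted strings, so the string interpolation is deferred until a handler actually emits the record. A minimal sketch of the difference; `cnt` here is a made-up stand-in for the op counter built in `_count_ops`.

```python
import logging

log = logging.getLogger("torch._functorch.partitioners")
cnt = {"aten.mm": 12, "aten.relu": 3}  # hypothetical op counts

# Eager: the f-string is built even when INFO records are filtered out.
log.info(f"{sorted(cnt.items(), key=lambda x: x[1], reverse=True)}")

# Deferred: the %s interpolation only runs if the record is emitted
# (the sorted() argument itself is still evaluated either way).
log.info("%s", sorted(cnt.items(), key=lambda x: x[1], reverse=True))
```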