Logging Refactor - Remove Print Statements (#139782)

Summary:
Removes print statements from the partitioner and replaces them with calls to Python's standard `logging` library.

Hopefully this will allow more control over the logging level when running models.
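
Switching to module-level loggers means callers can turn the partitioner's output up or down per module instead of always seeing prints on stdout. A minimal sketch of that control (the logger name `torch._functorch.partitioners` is an assumption here, based on the `logging.getLogger(__name__)` call in the diff):

```python
import logging

# Opt in to the partitioner's INFO/DEBUG messages without enabling
# verbose logging globally. Assumes the logger is named
# "torch._functorch.partitioners" via logging.getLogger(__name__).
logging.basicConfig(format="%(name)s [%(levelname)s] %(message)s")
logging.getLogger("torch._functorch.partitioners").setLevel(logging.DEBUG)
```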

Test Plan:
```
AOT_PARTITIONER_DEBUG=1 buck2 run @mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2
```

* Resulting output paste: P1674535630
* Full logs paste: P1674535621

```
pastry P1674535621 | grep "functorch/partitioners.py" | pastry
```

Logging results: P1674549514
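
For readers without access to the internal pastes, a rough OSS-style sketch of hitting the same code path: any `torch.compile` workload with a backward pass goes through the joint-graph partitioner, and the existing `AOT_PARTITIONER_DEBUG=1` flag gates the extra statistics. The tiny function below is purely illustrative; the logger name is assumed as above.

```python
# Run as: AOT_PARTITIONER_DEBUG=1 python repro.py
import logging

import torch

# Assumed logger name, based on logging.getLogger(__name__) in partitioners.py.
logging.basicConfig()
logging.getLogger("torch._functorch.partitioners").setLevel(logging.DEBUG)


@torch.compile
def f(x):
    return torch.sin(x).cos().sum()


# The backward pass makes AOT autograd partition the joint graph,
# which is where the new log.info/log.debug calls fire.
f(torch.randn(8, requires_grad=True)).backward()
```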

Differential Revision: D61678215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139782
Approved by: https://github.com/paryxyt, https://github.com/jansel
commit 2f1dbfea02 (parent b34bb1f562)
Author: Basil Wong
Date: 2024-11-13 23:09:16 +00:00
Committed by: PyTorch MergeBot


@@ -37,8 +37,8 @@ if TYPE_CHECKING:
import sympy
AOT_PARTITIONER_DEBUG = config.debug_partitioner
log = logging.getLogger(__name__)
AOT_PARTITIONER_DEBUG: bool = config.debug_partitioner
log: logging.Logger = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
@@ -510,7 +510,7 @@ def _count_ops(graph: fx.Graph):
for node in graph.nodes:
if node.op == "call_function":
cnt[node.target.__name__] += 1
print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))
log.info("%s", sorted(cnt.items(), key=lambda x: x[1], reverse=True))
@functools.lru_cache(None)
@@ -824,8 +824,7 @@ def solve_min_cut(
if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
}
ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
print("Ops banned from re-materialization: ", ops_ignored)
print()
log.info("Ops banned from re-materialization: %s", ops_ignored)
def can_fuse_into_auto_functionalized(a, b):
if b.target != torch.ops.higher_order.auto_functionalized:
@@ -921,7 +920,7 @@ def solve_min_cut(
if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(
node
):
log.info("materialized backwards: %s %s", node, tuple(node.users))
log.debug("materialized backwards: %s %s", node, tuple(node.users))
return True
# Arbitrary hack that sometimes seems to help things. The above
@@ -1171,8 +1170,8 @@ def solve_min_cut(
try:
cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
except Exception:
print("Failed to compute min-cut on following graph:")
print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
log.info("Failed to compute min-cut on following graph:")
log.info("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
visualize_min_cut_graph(nx_graph)
raise
@@ -1209,7 +1208,7 @@ def visualize_min_cut_graph(nx_graph):
# Color edges with weight 'inf' as red
if weight == float("inf"):
edge.set_color("red")
print("Visualizing the failed graph to min_cut_failed.svg")
log.info("Visualizing the failed graph to min_cut_failed.svg")
dot_graph.write_svg("min_cut_failed.svg")
@@ -1897,7 +1896,6 @@ def min_cut_rematerialization_partition(
if isinstance(node.meta.get("memory_budget", None), float):
memory_budget = node.meta["memory_budget"]
break
# print("Memory Budget: ", memory_budget)
saved_values = choose_saved_values_set(
joint_graph,
node_info,
@@ -1923,14 +1921,15 @@
bw_module = reordering_to_mimic_autograd_engine(bw_module)
if AOT_PARTITIONER_DEBUG:
from torch._inductor.fx_utils import get_node_storage
storages = {get_node_storage(node) for node in saved_values}
print(
"Theoretical Activations Stored: ",
sum(_size_of(i) for i in saved_values) / 1e9,
)
# Calculate sorted sizes of saved values
sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])
# Log total theoretical activations stored
total_activations_size_gb = sum(_size_of(i) for i in saved_values) / 1e9
log.debug("Theoretical Activations Stored: %.2f GB", total_activations_size_gb)
# Log theoretical per activation storage sizes
log.debug("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
fw_module_nodes = {
node.name for node in fw_module.graph.nodes if node.op == "call_function"
}
@@ -1943,13 +1942,14 @@
for node in fw_module.graph.nodes:
if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
counts[str(node.target._overloadpacket)] += 1
print(
f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}"
)
print(
"Count of Ops Rematerialized: ",
sorted(counts.items(), key=lambda x: x[1], reverse=True),
log.debug(
"# remat/fw/bw: %d/%d/%d",
len(remat_nodes),
len(fw_module_nodes),
len(bw_module_nodes),
)
rematerialized_ops = sorted(counts.items(), key=lambda x: x[1], reverse=True)
log.debug("Count of Ops Rematerialized: %s", rematerialized_ops)
return fw_module, bw_module
@@ -1970,7 +1970,7 @@ def draw_graph(
base, ext = os.path.splitext(fname)
if not ext:
ext = "." + config.torch_compile_graph_format
print(f"Writing FX graph to file: {base}{ext}")
log.info("Writing FX graph to file: %s%s", base, ext)
g = graph_drawer.FxGraphDrawer(
traced,
figname,