Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Logging Refactor - Remove Print Statements (#139782)
Summary: Removes print statements and implements logging via the logging library. Hopefully this will allow more control over the level of logging when running models.

Test Plan:
```
AOT_PARTITIONER_DEBUG=1 buck2 run @mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2
```
Resulting output paste: P1674535630
* Full logs paste: P1674535621
```
pastry P1674535621 | grep "functorch/partitioners.py" | pastry
```
Logging results: P1674549514

Differential Revision: D61678215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139782
Approved by: https://github.com/paryxyt, https://github.com/jansel
committed by PyTorch MergeBot
parent b34bb1f562
commit 2f1dbfea02
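With the prints routed through the standard `logging` library, callers can dial the verbosity per run instead of always seeing the output. As a minimal sketch (not part of the patch), assuming the partitioner's logger name follows its module path `torch._functorch.partitioners` because the diff below creates it with `logging.getLogger(__name__)`, the new `log.info`/`log.debug` messages could be surfaced like this:

```
import logging
import os

# Optional: the Test Plan above gates the extra partitioner diagnostics behind
# AOT_PARTITIONER_DEBUG=1; set it before the compile/partition step runs.
os.environ["AOT_PARTITIONER_DEBUG"] = "1"

# Send log records to stderr and enable DEBUG only for the partitioner logger,
# whose name is assumed to mirror the module path (logging.getLogger(__name__)).
logging.basicConfig(level=logging.WARNING)
logging.getLogger("torch._functorch.partitioners").setLevel(logging.DEBUG)
```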
```
@@ -37,8 +37,8 @@ if TYPE_CHECKING:
    import sympy


AOT_PARTITIONER_DEBUG = config.debug_partitioner
log = logging.getLogger(__name__)
AOT_PARTITIONER_DEBUG: bool = config.debug_partitioner
log: logging.Logger = logging.getLogger(__name__)

aten = torch.ops.aten
prims = torch.ops.prims
```
```
@@ -510,7 +510,7 @@ def _count_ops(graph: fx.Graph):
    for node in graph.nodes:
        if node.op == "call_function":
            cnt[node.target.__name__] += 1
    print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))
    log.info("%s", sorted(cnt.items(), key=lambda x: x[1], reverse=True))


@functools.lru_cache(None)
```
```
@@ -824,8 +824,7 @@ def solve_min_cut(
        if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
    }
    ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
    print("Ops banned from re-materialization: ", ops_ignored)
    print()
    log.info("Ops banned from re-materialization: %s", ops_ignored)

    def can_fuse_into_auto_functionalized(a, b):
        if b.target != torch.ops.higher_order.auto_functionalized:
```
```
@@ -921,7 +920,7 @@ def solve_min_cut(
    if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(
        node
    ):
        log.info("materialized backwards: %s %s", node, tuple(node.users))
        log.debug("materialized backwards: %s %s", node, tuple(node.users))
        return True

    # Arbitrary hack that sometimes seems to help things. The above
```
```
@@ -1171,8 +1170,8 @@ def solve_min_cut(
    try:
        cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
    except Exception:
        print("Failed to compute min-cut on following graph:")
        print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
        log.info("Failed to compute min-cut on following graph:")
        log.info("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
        visualize_min_cut_graph(nx_graph)
        raise

```
```
@@ -1209,7 +1208,7 @@ def visualize_min_cut_graph(nx_graph):
        # Color edges with weight 'inf' as red
        if weight == float("inf"):
            edge.set_color("red")
    print("Visualizing the failed graph to min_cut_failed.svg")
    log.info("Visualizing the failed graph to min_cut_failed.svg")
    dot_graph.write_svg("min_cut_failed.svg")

```
```
@@ -1897,7 +1896,6 @@ def min_cut_rematerialization_partition(
        if isinstance(node.meta.get("memory_budget", None), float):
            memory_budget = node.meta["memory_budget"]
            break
    # print("Memory Budget: ", memory_budget)
    saved_values = choose_saved_values_set(
        joint_graph,
        node_info,
```
```
@@ -1923,14 +1921,15 @@ def min_cut_rematerialization_partition(
    bw_module = reordering_to_mimic_autograd_engine(bw_module)

    if AOT_PARTITIONER_DEBUG:
        from torch._inductor.fx_utils import get_node_storage

        storages = {get_node_storage(node) for node in saved_values}
        print(
            "Theoretical Activations Stored: ",
            sum(_size_of(i) for i in saved_values) / 1e9,
        )
        # Calculate sorted sizes of saved values
        sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])

        # Log total theoretical activations stored
        total_activations_size_gb = sum(_size_of(i) for i in saved_values) / 1e9
        log.debug("Theoretical Activations Stored: %.2f GB", total_activations_size_gb)

        # Log theoretical per activation storage sizes
        log.debug("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
        fw_module_nodes = {
            node.name for node in fw_module.graph.nodes if node.op == "call_function"
        }
```
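Note that the new calls hand the values to the logger as arguments (`log.debug("...: %s", sorted_sizes)`) rather than pre-formatting a string, so the message is only rendered when the DEBUG level is actually enabled. A small stand-alone illustration of that behavior (the names here are illustrative, not from the patch):

```
import logging

logging.basicConfig(level=logging.INFO)  # DEBUG records are filtered out
log = logging.getLogger("demo")

big_list = list(range(1_000_000))

# %-style arguments: str(big_list) is never built because DEBUG is disabled.
log.debug("Theoretical Per Activation Storage Sizes: %s", big_list)

# An f-string would format the whole list eagerly, even though the record is dropped.
log.debug(f"Theoretical Per Activation Storage Sizes: {big_list}")
```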
```
@@ -1943,13 +1942,14 @@ def min_cut_rematerialization_partition(
        for node in fw_module.graph.nodes:
            if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
                counts[str(node.target._overloadpacket)] += 1
        print(
            f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}"
        )
        print(
            "Count of Ops Rematerialized: ",
            sorted(counts.items(), key=lambda x: x[1], reverse=True),
        log.debug(
            "# remat/fw/bw: %d/%d/%d",
            len(remat_nodes),
            len(fw_module_nodes),
            len(bw_module_nodes),
        )
        rematerialized_ops = sorted(counts.items(), key=lambda x: x[1], reverse=True)
        log.debug("Count of Ops Rematerialized: %s", rematerialized_ops)
    return fw_module, bw_module

```
```
@@ -1970,7 +1970,7 @@ def draw_graph(
    base, ext = os.path.splitext(fname)
    if not ext:
        ext = "." + config.torch_compile_graph_format
    print(f"Writing FX graph to file: {base}{ext}")
    log.info("Writing FX graph to file: %s%s", base, ext)
    g = graph_drawer.FxGraphDrawer(
        traced,
        figname,
```