Files
pytorch/tools/experimental/torchfuzz/visualize_graph.py

228 lines
6.9 KiB
Python

# mypy: ignore-errors
"""
Visualization tools for operation stacks and graphs as DAGs.
"""
import subprocess
from torchfuzz.ops_fuzzer import OperationGraph
from torchfuzz.tensor_fuzzer import TensorSpec
def save_and_render_dot(dot_content: str, filename: str = "operation_stack"):
    """
    Save DOT content to ``<filename>.dot`` and render it to ``<filename>.png``.

    Rendering shells out to the Graphviz ``dot`` executable and is
    best-effort: if ``dot`` is not installed or exits with an error, the
    ``.dot`` file is still written and a warning is printed instead of raising.

    Args:
        dot_content: DOT format string
        filename: Base filename (without extension). Missing parent
            directories are created.
    """
    import os

    dot_file = f"{filename}.dot"
    png_file = f"{filename}.png"

    # Make sure the destination folder exists (e.g. a fresh output_folder).
    parent_dir = os.path.dirname(dot_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Save DOT file; be explicit about encoding so output is platform-independent.
    with open(dot_file, "w", encoding="utf-8") as f:
        f.write(dot_content)

    # Render to PNG (best-effort: Graphviz is an optional dependency).
    try:
        subprocess.run(["dot", "-Tpng", dot_file, "-o", png_file], check=True)
    except FileNotFoundError:
        print(
            "⚠️ Graphviz 'dot' executable not found; DOT source saved to "
            f"file://{os.path.abspath(dot_file)}"
        )
    except subprocess.CalledProcessError as err:
        print(
            f"⚠️ Graphviz 'dot' failed (exit code {err.returncode}); DOT source saved to "
            f"file://{os.path.abspath(dot_file)}"
        )
    else:
        # Absolute path gives a clickable link in most terminals.
        print(f"🖼️ View: file://{os.path.abspath(png_file)}")
def operation_graph_to_dot(
    graph: OperationGraph, title: str = "Operation Graph"
) -> str:
    """
    Convert an operation graph to Graphviz DOT format for visualization.

    Node styling depends on ``op_name``:

    * ``arg_*`` nodes are light-blue ellipses,
    * ``constant`` nodes are light-green ellipses,
    * names containing ``aten`` are light-yellow boxes,
    * everything else is a light-gray box.

    The node whose ID equals ``graph.root_node_id`` is additionally outlined
    in thick red. Each edge goes from an input node to its consumer and is
    labelled with the input position (``input_0``, ``input_1``, ...).

    Node IDs and all label text (including ``title``) are emitted as quoted,
    escaped DOT strings, so IDs or titles containing spaces, dashes, quotes
    or backslashes still yield valid DOT.

    Args:
        graph: OperationGraph instance
        title: Title for the graph

    Returns:
        DOT format string
    """

    def escape(value) -> str:
        # Escape text for use inside a double-quoted DOT string.
        return str(value).replace("\\", "\\\\").replace('"', '\\"')

    dot_lines = [
        "digraph OperationGraph {",
        f' label="{escape(title)}";',
        " rankdir=TB;",  # Top to bottom layout
        " node [shape=box, style=filled, fontsize=10];",
        " edge [fontsize=8];",
        "",
    ]

    # Add nodes with styling based on operation type
    for node_id, node in graph.nodes.items():
        op_name = node.op_name
        is_arg = op_name.startswith("arg_")

        if is_arg:
            color, shape = "lightblue", "ellipse"
        elif op_name == "constant":
            color, shape = "lightgreen", "ellipse"
        elif "aten" in op_name:
            color, shape = "lightyellow", "box"
        else:
            color, shape = "lightgray", "box"

        # Argument nodes show only their name; other nodes also show ID and depth.
        if is_arg:
            label_parts = [op_name]
        else:
            label_parts = [node_id, op_name, f"depth {node.depth}"]

        spec = node.output_spec
        if hasattr(spec, "dtype"):
            label_parts.append(str(spec.dtype).replace("torch.", ""))
        # Only add size for TensorSpec, not ScalarSpec
        if isinstance(spec, TensorSpec) and spec.size:
            label_parts.append("size " + "x".join(map(str, spec.size)))

        # "\\n" is DOT's in-label line break, so escape each part before joining.
        label = "\\n".join(escape(part) for part in label_parts)

        # Special highlighting for root node
        extra_style = ", penwidth=3, color=red" if node_id == graph.root_node_id else ""
        dot_lines.append(
            f' "{escape(node_id)}" [label="{label}", fillcolor="{color}", '
            f'shape="{shape}"{extra_style}];'
        )

    dot_lines.append("")

    # Add edges from each input node to its consumer, labelled by input position
    for node_id, node in graph.nodes.items():
        for i, input_node_id in enumerate(node.input_nodes):
            dot_lines.append(
                f' "{escape(input_node_id)}" -> "{escape(node_id)}" [label="input_{i}"];'
            )

    dot_lines.extend(
        [
            "",
            " // Legend",
            " subgraph cluster_legend {",
            ' label="Legend";',
            " style=filled;",
            " fillcolor=white;",
            ' legend_arg [label="arg", fillcolor=lightblue, shape=ellipse];',
            ' legend_const [label="constant", fillcolor=lightgreen, shape=ellipse];',
            ' legend_aten [label="aten ops", fillcolor=lightyellow, shape=box];',
            ' legend_other [label="other ops", fillcolor=lightgray, shape=box];',
            # The root keeps its type colour; only the thick red outline marks it.
            ' legend_root [label="root", fillcolor=white, shape=box, penwidth=3, color=red];',
            " }",
            "}",
        ]
    )
    return "\n".join(dot_lines)
def visualize_operation_graph(
    graph: OperationGraph,
    title: str = "Operation Graph",
    output_folder: str = ".",
):
    """
    Render an operation graph end to end.

    The graph is converted to DOT with ``operation_graph_to_dot`` and then
    handed to ``save_and_render_dot`` using the base path
    ``<output_folder>/operation_graph``.

    Args:
        graph: OperationGraph instance
        title: Title for the visualization
        output_folder: Folder where to save the visualization files
    """
    import os

    dot_source = operation_graph_to_dot(graph, title)
    base_path = os.path.join(output_folder, "operation_graph")
    save_and_render_dot(dot_source, base_path)
def operation_graph_to_networkx(
    graph: OperationGraph,
    *,
    output_path: str = "operation_graph_networkx.png",
    show: bool = True,
):
    """
    Convert an operation graph to a NetworkX ``DiGraph`` and plot it.

    Requires: pip install networkx matplotlib

    Each NetworkX node is keyed by the operation-graph node ID and carries a
    ``label`` attribute plus the original ``node`` object. Edges go from each
    input node to its consumer; inputs that are not in ``graph.nodes`` are
    skipped. Node colours follow the same op-type scheme as the DOT
    visualization, and the root node is drawn larger.

    The plot is saved to ``output_path`` (300 dpi) and, if ``show`` is true,
    displayed via ``plt.show()``. The figure is always closed afterwards, so
    repeated calls do not accumulate open Matplotlib figures.

    Args:
        graph: OperationGraph instance
        output_path: Destination of the PNG (defaults to the current directory).
        show: Whether to call ``plt.show()`` after saving.

    Returns:
        The constructed ``networkx.DiGraph``, or ``None`` if NetworkX or
        Matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
        import networkx as nx
    except ImportError:
        print(
            "⚠️ NetworkX/Matplotlib not installed. Run: pip install networkx matplotlib"
        )
        return None

    def node_color(op_name: str) -> str:
        # Same colour scheme as operation_graph_to_dot.
        if op_name.startswith("arg_"):
            return "lightblue"
        if op_name == "constant":
            return "lightgreen"
        if "aten" in op_name:
            return "lightyellow"
        return "lightgray"

    # Create directed graph
    G = nx.DiGraph()
    for node_id, node in graph.nodes.items():
        label = f"{node_id}\n{node.op_name}\ndepth {node.depth}"
        G.add_node(node_id, label=label, node=node)

    # Add edges based on the graph structure
    for node_id, node in graph.nodes.items():
        for input_node_id in node.input_nodes:
            if input_node_id in graph.nodes:  # Only add edges to nodes in the graph
                G.add_edge(input_node_id, node_id)

    fig = plt.figure(figsize=(12, 8))
    try:
        pos = nx.spring_layout(G, k=2, iterations=50)

        node_colors = [node_color(graph.nodes[n].op_name) for n in G.nodes()]
        # Highlight root node with a larger marker
        node_sizes = [2000 if n == graph.root_node_id else 1500 for n in G.nodes()]

        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes)
        nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True, arrowsize=20)

        labels = {n: f"{n}\n{graph.nodes[n].op_name}" for n in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title("Operation Graph Visualization")
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        if show:
            plt.show()
    finally:
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)

    print(f"✓ NetworkX graph visualization saved as {output_path}")
    return G