[torchfuzz] keep track of operator stats (#164334)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164334
Approved by: https://github.com/pianpwk
ghstack dependencies: #164034, #164209, #164211, #164210, #164397, #164284
This commit is contained in:
bobrenjc93
2025-10-01 15:59:16 -07:00
committed by PyTorch MergeBot
parent 0fbe3f19c7
commit 39b31a6bfd
2 changed files with 122 additions and 5 deletions

View File

@ -122,6 +122,43 @@ def fuzz_and_execute(
operation_graph = fuzz_operation_graph(
target_spec, max_depth=max_depth, seed=seed, template=template
)
# Extract and print operation statistics
operation_counts = {}
for node in operation_graph.nodes.values():
# Use the fully qualified torch operation name if available
from torchfuzz.operators import get_operator
# Try to get the fully qualified torch operation name
torch_op_name = None
# Extract the base operation name (without arg_X suffixes)
base_op_name = node.op_name
if node.op_name.startswith("arg_"):
# For arg operations, use just "arg" to look up in registry
base_op_name = "arg"
try:
operator = get_operator(base_op_name)
if (
operator
and hasattr(operator, "torch_op_name")
and operator.torch_op_name
):
torch_op_name = operator.torch_op_name
except (KeyError, ValueError):
# If the operator doesn't exist in registry, use the node's op_name
pass
# Use fully qualified name if available, otherwise use the node's op_name
display_name = torch_op_name if torch_op_name else node.op_name
operation_counts[display_name] = operation_counts.get(display_name, 0) + 1
# Print operation statistics in a parseable format
print("OPERATION_STATS:")
for op_name, count in sorted(operation_counts.items()):
print(f" {op_name}: {count}")
logger.debug("⏱️ Step 3: Converting to Python code...")
start_time = time.time()
python_code = convert_graph_to_python_code(

View File

@ -8,6 +8,7 @@ import re
import subprocess
import sys
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
@ -19,11 +20,19 @@ try:
except ImportError:
HAS_TQDM = False
# Fallback used when tqdm is not installed: it provides the one tqdm entry
# point this module relies on (``write``) so call sites need no special-casing.
class MockTqdm:
    """Minimal stand-in for ``tqdm`` exposing only ``write``."""

    @staticmethod
    def write(msg, file=None, end="\n", nolock=False):
        """Print ``msg`` to ``file`` and flush immediately.

        The keyword parameters mirror the real ``tqdm.write`` so that a call
        written against tqdm's API does not raise ``TypeError`` when tqdm is
        missing.

        Args:
            msg: Text to print.
            file: Destination stream; ``None`` lets ``print`` use its default.
            end: Terminator appended after ``msg``.
            nolock: Accepted for signature compatibility only and ignored,
                since there is no progress bar whose lock would need managing.
        """
        print(msg, file=file, end=end, flush=True)
tqdm = MockTqdm()
def persist_print(msg):
"""Print messages that persist with tqdm progress bars."""
try:
if HAS_TQDM:
if HAS_TQDM and hasattr(tqdm, "write"):
# Keep prints on the same stream as the bar
tqdm.write(msg, file=sys.stderr)
else:
@ -62,6 +71,7 @@ class FuzzerResult:
output: str
duration: float
ignored_pattern_idx: int
operation_stats: dict[str, int] # New field for operation statistics
def is_ignored_output(output: str) -> int:
@ -123,23 +133,51 @@ def run_fuzzer_with_seed(seed: int, template: str = "default") -> FuzzerResult:
output += f"STDERR:\n{result.stderr}\n"
output += f"Return code: {result.returncode}"
# Parse operation statistics from the output
operation_stats = {}
if result.stdout:
lines = result.stdout.split("\n")
in_stats_section = False
for line in lines:
if line.strip() == "OPERATION_STATS:":
in_stats_section = True
continue
elif in_stats_section:
if line.startswith(" ") and ":" in line:
# Parse line like " torch.add: 3"
op_line = line.strip()
if ": " in op_line:
op_name, count_str = op_line.split(": ", 1)
try:
count = int(count_str)
operation_stats[op_name] = count
except ValueError:
pass # Skip malformed lines
else:
# End of stats section
in_stats_section = False
# Check if output should be ignored and which pattern matched
ignored_pattern_idx = is_ignored_output(output)
if ignored_pattern_idx != -1:
# Mark as ignored (could also return a special flag if needed)
output = "[IGNORED] " + output
return FuzzerResult(seed, success, output, duration, ignored_pattern_idx)
return FuzzerResult(
seed, success, output, duration, ignored_pattern_idx, operation_stats
)
except subprocess.TimeoutExpired:
duration = time.time() - start_time
return FuzzerResult(
seed, False, "Process timed out after 300 seconds", duration, -1
seed, False, "Process timed out after 300 seconds", duration, -1, {}
)
except Exception as e:
duration = time.time() - start_time
return FuzzerResult(seed, False, f"Exception occurred: {str(e)}", duration, -1)
return FuzzerResult(
seed, False, f"Exception occurred: {str(e)}", duration, -1, {}
)
def print_output_lines(output: str, write_func):
@ -217,6 +255,8 @@ def run_multi_process_fuzzer(
# Set up progress bar
if HAS_TQDM:
from tqdm import tqdm # Import the real tqdm here
pbar = tqdm(
total=len(seeds),
desc="Processing seeds",
@ -315,7 +355,9 @@ def run_multi_process_fuzzer(
)
persist_print(f"❌ POOL ERROR - Seed {seeds[i]}: {str(e)}")
results.append(
FuzzerResult(seeds[i], False, f"Pool error: {str(e)}", 0.0, -1)
FuzzerResult(
seeds[i], False, f"Pool error: {str(e)}", 0.0, -1, {}
)
)
# Close progress bar
@ -367,6 +409,9 @@ def run_multi_process_fuzzer(
f" Pattern {idx}: {pattern.pattern!r} - {count} ({percent:.1f}%)"
)
# Aggregate and print operation distribution
_print_operation_distribution(results)
sys.exit(130)
total_time = time.time() - start_time
@ -415,3 +460,38 @@ def run_multi_process_fuzzer(
persist_print(
f" Pattern {idx}: {pattern.pattern!r} - {count} ({percent:.1f}%)"
)
# Aggregate and print operation distribution
_print_operation_distribution(results)
def _print_operation_distribution(results: list[FuzzerResult]) -> None:
    """Report how often each operation appeared across successful runs.

    The ``operation_stats`` mapping of every result whose ``success`` flag is
    set is summed per operation name. Each operation is then emitted through
    ``persist_print`` with its total and its share of the grand total, most
    frequent first. If nothing was collected, a single notice is printed.

    Args:
        results: Fuzzer results to aggregate; results that failed or carry
            no stats are skipped.
    """
    tally = defaultdict(int)
    for res in results:
        # Only successful runs that actually reported stats contribute.
        if not (res.success and res.operation_stats):
            continue
        for name, occurrences in res.operation_stats.items():
            tally[name] += occurrences

    if not tally:
        persist_print(
            "\n📊 No operation statistics collected (no successful runs with stats)"
        )
        return

    grand_total = sum(tally.values())

    persist_print("\n📊 OPERATION DISTRIBUTION")
    persist_print("=" * 60)
    persist_print(f"Total operations executed: {grand_total}")
    persist_print("")

    # Highest count first; the sort is stable, so ties keep first-seen order.
    ranked = list(tally.items())
    ranked.sort(key=lambda entry: entry[1], reverse=True)

    for name, occurrences in ranked:
        # Guard against a zero grand total (every parsed count was 0).
        if grand_total > 0:
            share = occurrences / grand_total * 100
        else:
            share = 0
        persist_print(f" {name:<30} {occurrences:>6} times ({share:>5.1f}%)")