mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[torchfuzz] keep track of operator stats (#164334)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164334 Approved by: https://github.com/pianpwk ghstack dependencies: #164034, #164209, #164211, #164210, #164397, #164284
This commit is contained in:
committed by
PyTorch MergeBot
parent
0fbe3f19c7
commit
39b31a6bfd
@ -122,6 +122,43 @@ def fuzz_and_execute(
|
||||
operation_graph = fuzz_operation_graph(
|
||||
target_spec, max_depth=max_depth, seed=seed, template=template
|
||||
)
|
||||
|
||||
# Extract and print operation statistics
|
||||
operation_counts = {}
|
||||
for node in operation_graph.nodes.values():
|
||||
# Use the fully qualified torch operation name if available
|
||||
from torchfuzz.operators import get_operator
|
||||
|
||||
# Try to get the fully qualified torch operation name
|
||||
torch_op_name = None
|
||||
|
||||
# Extract the base operation name (without arg_X suffixes)
|
||||
base_op_name = node.op_name
|
||||
if node.op_name.startswith("arg_"):
|
||||
# For arg operations, use just "arg" to look up in registry
|
||||
base_op_name = "arg"
|
||||
|
||||
try:
|
||||
operator = get_operator(base_op_name)
|
||||
if (
|
||||
operator
|
||||
and hasattr(operator, "torch_op_name")
|
||||
and operator.torch_op_name
|
||||
):
|
||||
torch_op_name = operator.torch_op_name
|
||||
except (KeyError, ValueError):
|
||||
# If the operator doesn't exist in registry, use the node's op_name
|
||||
pass
|
||||
|
||||
# Use fully qualified name if available, otherwise use the node's op_name
|
||||
display_name = torch_op_name if torch_op_name else node.op_name
|
||||
operation_counts[display_name] = operation_counts.get(display_name, 0) + 1
|
||||
|
||||
# Print operation statistics in a parseable format
|
||||
print("OPERATION_STATS:")
|
||||
for op_name, count in sorted(operation_counts.items()):
|
||||
print(f" {op_name}: {count}")
|
||||
|
||||
logger.debug("⏱️ Step 3: Converting to Python code...")
|
||||
start_time = time.time()
|
||||
python_code = convert_graph_to_python_code(
|
||||
|
@ -8,6 +8,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@ -19,11 +20,19 @@ try:
|
||||
except ImportError:
|
||||
HAS_TQDM = False
|
||||
|
||||
# Create a mock tqdm class for type safety
|
||||
class MockTqdm:
|
||||
@staticmethod
|
||||
def write(msg, file=None):
|
||||
print(msg, file=file, flush=True)
|
||||
|
||||
tqdm = MockTqdm()
|
||||
|
||||
|
||||
def persist_print(msg):
|
||||
"""Print messages that persist with tqdm progress bars."""
|
||||
try:
|
||||
if HAS_TQDM:
|
||||
if HAS_TQDM and hasattr(tqdm, "write"):
|
||||
# Keep prints on the same stream as the bar
|
||||
tqdm.write(msg, file=sys.stderr)
|
||||
else:
|
||||
@ -62,6 +71,7 @@ class FuzzerResult:
|
||||
output: str
|
||||
duration: float
|
||||
ignored_pattern_idx: int
|
||||
operation_stats: dict[str, int] # New field for operation statistics
|
||||
|
||||
|
||||
def is_ignored_output(output: str) -> int:
|
||||
@ -123,23 +133,51 @@ def run_fuzzer_with_seed(seed: int, template: str = "default") -> FuzzerResult:
|
||||
output += f"STDERR:\n{result.stderr}\n"
|
||||
output += f"Return code: {result.returncode}"
|
||||
|
||||
# Parse operation statistics from the output
|
||||
operation_stats = {}
|
||||
if result.stdout:
|
||||
lines = result.stdout.split("\n")
|
||||
in_stats_section = False
|
||||
for line in lines:
|
||||
if line.strip() == "OPERATION_STATS:":
|
||||
in_stats_section = True
|
||||
continue
|
||||
elif in_stats_section:
|
||||
if line.startswith(" ") and ":" in line:
|
||||
# Parse line like " torch.add: 3"
|
||||
op_line = line.strip()
|
||||
if ": " in op_line:
|
||||
op_name, count_str = op_line.split(": ", 1)
|
||||
try:
|
||||
count = int(count_str)
|
||||
operation_stats[op_name] = count
|
||||
except ValueError:
|
||||
pass # Skip malformed lines
|
||||
else:
|
||||
# End of stats section
|
||||
in_stats_section = False
|
||||
|
||||
# Check if output should be ignored and which pattern matched
|
||||
ignored_pattern_idx = is_ignored_output(output)
|
||||
if ignored_pattern_idx != -1:
|
||||
# Mark as ignored (could also return a special flag if needed)
|
||||
output = "[IGNORED] " + output
|
||||
|
||||
return FuzzerResult(seed, success, output, duration, ignored_pattern_idx)
|
||||
return FuzzerResult(
|
||||
seed, success, output, duration, ignored_pattern_idx, operation_stats
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
duration = time.time() - start_time
|
||||
return FuzzerResult(
|
||||
seed, False, "Process timed out after 300 seconds", duration, -1
|
||||
seed, False, "Process timed out after 300 seconds", duration, -1, {}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
return FuzzerResult(seed, False, f"Exception occurred: {str(e)}", duration, -1)
|
||||
return FuzzerResult(
|
||||
seed, False, f"Exception occurred: {str(e)}", duration, -1, {}
|
||||
)
|
||||
|
||||
|
||||
def print_output_lines(output: str, write_func):
|
||||
@ -217,6 +255,8 @@ def run_multi_process_fuzzer(
|
||||
|
||||
# Set up progress bar
|
||||
if HAS_TQDM:
|
||||
from tqdm import tqdm # Import the real tqdm here
|
||||
|
||||
pbar = tqdm(
|
||||
total=len(seeds),
|
||||
desc="Processing seeds",
|
||||
@ -315,7 +355,9 @@ def run_multi_process_fuzzer(
|
||||
)
|
||||
persist_print(f"❌ POOL ERROR - Seed {seeds[i]}: {str(e)}")
|
||||
results.append(
|
||||
FuzzerResult(seeds[i], False, f"Pool error: {str(e)}", 0.0, -1)
|
||||
FuzzerResult(
|
||||
seeds[i], False, f"Pool error: {str(e)}", 0.0, -1, {}
|
||||
)
|
||||
)
|
||||
|
||||
# Close progress bar
|
||||
@ -367,6 +409,9 @@ def run_multi_process_fuzzer(
|
||||
f" Pattern {idx}: {pattern.pattern!r} - {count} ({percent:.1f}%)"
|
||||
)
|
||||
|
||||
# Aggregate and print operation distribution
|
||||
_print_operation_distribution(results)
|
||||
|
||||
sys.exit(130)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
@ -415,3 +460,38 @@ def run_multi_process_fuzzer(
|
||||
persist_print(
|
||||
f" Pattern {idx}: {pattern.pattern!r} - {count} ({percent:.1f}%)"
|
||||
)
|
||||
|
||||
# Aggregate and print operation distribution
|
||||
_print_operation_distribution(results)
|
||||
|
||||
|
||||
def _print_operation_distribution(results: list[FuzzerResult]) -> None:
    """Aggregate per-run operation counts and print the combined distribution.

    Only results that are flagged successful and carry a non-empty
    ``operation_stats`` mapping contribute to the totals. All output is
    emitted through ``persist_print``.

    Args:
        results: Fuzzer results whose ``operation_stats`` map an operation
            name to the number of times it appeared in that run.
    """
    aggregated = defaultdict(int)
    grand_total = 0

    # Fold every qualifying run into one name -> count table.
    for res in results:
        if not (res.success and res.operation_stats):
            continue
        for name, occurrences in res.operation_stats.items():
            aggregated[name] += occurrences
            grand_total += occurrences

    # Nothing gathered: report that and stop.
    if not aggregated:
        persist_print(
            "\n📊 No operation statistics collected (no successful runs with stats)"
        )
        return

    persist_print("\n📊 OPERATION DISTRIBUTION")
    persist_print("=" * 60)
    persist_print(f"Total operations executed: {grand_total}")
    persist_print("")

    # Most frequent operations first, for readability.
    ranked = sorted(aggregated.items(), key=lambda item: item[1], reverse=True)
    for name, occurrences in ranked:
        # Guard against a zero total (possible if every recorded count is 0).
        share = (occurrences / grand_total * 100) if grand_total > 0 else 0
        persist_print(f" {name:<30} {occurrences:>6} times ({share:>5.1f}%)")
|
||||
|
Reference in New Issue
Block a user