[torchfuzz] refactor multi_process_fuzzer to be more readable (#163698)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163698
Approved by: https://github.com/pianpwk
ghstack dependencies: #163547, #163553, #163554, #163555, #163556, #163557, #163558, #163560
This commit is contained in:
bobrenjc93
2025-09-23 15:27:23 -07:00
committed by PyTorch MergeBot
parent 754c7e2e88
commit d927e55498

View File

@ -8,6 +8,7 @@ import re
import subprocess
import sys
import time
from dataclasses import dataclass
try:
@ -53,6 +54,15 @@ IGNORE_PATTERNS: list[re.Pattern] = [
]
@dataclass
class FuzzerResult:
seed: int
success: bool
output: str
duration: float
ignored_pattern_idx: int
def is_ignored_output(output: str) -> int:
"""
Check if the output matches any ignore pattern.
@ -69,7 +79,7 @@ def is_ignored_output(output: str) -> int:
return -1
def run_fuzzer_with_seed(seed: int) -> tuple[int, bool, str, float, int]:
def run_fuzzer_with_seed(seed: int) -> FuzzerResult:
"""
Run fuzzer.py with a specific seed.
@ -77,8 +87,7 @@ def run_fuzzer_with_seed(seed: int) -> tuple[int, bool, str, float, int]:
seed: The seed value to pass to fuzzer.py
Returns:
Tuple of (seed, success, output, duration, ignored_pattern_idx)
ignored_pattern_idx: -1 if not ignored, otherwise index of IGNORE_PATTERNS
FuzzerResult dataclass instance
"""
start_time = time.time()
@ -110,15 +119,44 @@ def run_fuzzer_with_seed(seed: int) -> tuple[int, bool, str, float, int]:
# Mark as ignored (could also return a special flag if needed)
output = "[IGNORED] " + output
return seed, success, output, duration, ignored_pattern_idx
return FuzzerResult(seed, success, output, duration, ignored_pattern_idx)
except subprocess.TimeoutExpired:
duration = time.time() - start_time
return seed, False, "Process timed out after 300 seconds", duration, -1
return FuzzerResult(
seed, False, "Process timed out after 300 seconds", duration, -1
)
except Exception as e:
duration = time.time() - start_time
return seed, False, f"Exception occurred: {str(e)}", duration, -1
return FuzzerResult(seed, False, f"Exception occurred: {str(e)}", duration, -1)
def print_output_lines(output: str, write_func):
"""Helper to print non-empty lines of output using the provided write_func."""
for line in output.split("\n"):
if line.strip():
write_func(f" {line}")
if hasattr(write_func, "__self__") and hasattr(write_func.__self__, "write"):
# For tqdm.write, add an empty line for separation
write_func("")
def handle_result_output(
*,
label: str,
seed: int,
duration: float,
output: str,
ignored: bool,
verbose: bool,
write_func,
):
"""Unified handler for result output, reducing code repetition."""
ignored_text = " [IGNORED]" if ignored else ""
write_func(f"{label} - Seed {seed} (duration: {duration:.2f}s){ignored_text}")
if output.strip() or label.startswith("") or verbose:
print_output_lines(output, write_func)
def run_multi_process_fuzzer(
@ -146,7 +184,7 @@ def run_multi_process_fuzzer(
persist_print("=" * 60)
start_time = time.time()
results = []
results: list[FuzzerResult] = []
successful_count = 0
failed_count = 0
ignored_count = 0
@ -176,27 +214,29 @@ def run_multi_process_fuzzer(
pbar.set_postfix_str(
f"{successful_count}/{failed_count}/{ignored_count} | throughput: 0.00 seeds/hr"
)
def write_func(msg):
pbar.write(msg)
else:
persist_print("Progress: (install tqdm for better progress bar)")
pbar = None
write_func = persist_print
# Collect results as they complete
for i, future in enumerate(future_results):
try:
seed, success, output, duration, ignored_pattern_idx = future.get()
results.append(
(seed, success, output, duration, ignored_pattern_idx)
)
result: FuzzerResult = future.get()
results.append(result)
if ignored_pattern_idx != -1:
ignored_seeds.append(seed)
ignored_pattern_counts[ignored_pattern_idx] += 1
if result.ignored_pattern_idx != -1:
ignored_seeds.append(result.seed)
ignored_pattern_counts[result.ignored_pattern_idx] += 1
ignored_count += 1
# Only increment failed_count if not ignored
if success:
if result.success:
successful_count += 1
elif ignored_pattern_idx == -1:
elif result.ignored_pattern_idx == -1:
failed_count += 1
elapsed = time.time() - start_time
@ -207,77 +247,48 @@ def run_multi_process_fuzzer(
pbar.set_postfix_str(
f"{successful_count}/{failed_count}/{ignored_count} | throughput: {throughput:.2f} seeds/hr"
)
# tqdm automatically shows ETA (estimated time remaining) in the bar_format above
pbar.update(1)
else:
status_emoji = "" if success else ""
ignored_text = " (IGNORED)" if ignored_pattern_idx != -1 else ""
status_emoji = "" if result.success else ""
ignored_text = (
" (IGNORED)" if result.ignored_pattern_idx != -1 else ""
)
persist_print(
f"Completed {i + 1}/{len(seeds)} - Seed {seed}: {status_emoji}{ignored_text}"
f"Completed {i + 1}/{len(seeds)} - Seed {result.seed}: {status_emoji}{ignored_text}"
)
# Only show detailed output for failures (unless verbose)
if not success and ignored_pattern_idx == -1:
if HAS_TQDM and pbar:
pbar.write(
f"❌ FAILURE - Seed {seed} (duration: {duration:.2f}s):"
)
for line in output.split("\n"):
if line.strip():
pbar.write(f" {line}")
pbar.write("") # Empty line
else:
persist_print(
f"❌ FAILURE - Seed {seed} (duration: {duration:.2f}s):"
)
for line in output.split("\n"):
if line.strip():
persist_print(f" {line}")
persist_print("")
elif not success and ignored_pattern_idx != -1:
# Optionally, print ignored failures if desired
# Unified output handling
if not result.success and result.ignored_pattern_idx == -1:
handle_result_output(
label="❌ FAILURE",
seed=result.seed,
duration=result.duration,
output=result.output,
ignored=False,
verbose=verbose,
write_func=write_func,
)
elif not result.success and result.ignored_pattern_idx != -1:
if verbose:
if HAS_TQDM and pbar:
pbar.write(
f"🚫 IGNORED - Seed {seed} (duration: {duration:.2f}s):"
)
for line in output.split("\n"):
if line.strip():
pbar.write(f" {line}")
pbar.write("")
else:
persist_print(
f"🚫 IGNORED - Seed {seed} (duration: {duration:.2f}s):"
)
for line in output.split("\n"):
if line.strip():
persist_print(f" {line}")
persist_print("")
handle_result_output(
label="🚫 IGNORED",
seed=result.seed,
duration=result.duration,
output=result.output,
ignored=True,
verbose=verbose,
write_func=write_func,
)
elif verbose:
if HAS_TQDM and pbar:
ignored_text = (
" [IGNORED]" if ignored_pattern_idx != -1 else ""
)
pbar.write(
f"✅ SUCCESS - Seed {seed} (duration: {duration:.2f}s){ignored_text}"
)
if output.strip():
for line in output.split("\n"):
if line.strip():
pbar.write(f" {line}")
pbar.write("")
else:
ignored_text = (
" [IGNORED]" if ignored_pattern_idx != -1 else ""
)
persist_print(
f"✅ SUCCESS - Seed {seed} (duration: {duration:.2f}s){ignored_text}"
)
if output.strip():
for line in output.split("\n"):
if line.strip():
persist_print(f" {line}")
persist_print("")
handle_result_output(
label="✅ SUCCESS",
seed=result.seed,
duration=result.duration,
output=result.output,
ignored=(result.ignored_pattern_idx != -1),
verbose=verbose,
write_func=write_func,
)
except Exception as e:
failed_count += 1
@ -290,7 +301,9 @@ def run_multi_process_fuzzer(
f"Completed {i + 1}/{len(seeds)} - Seed {seeds[i]}: ❌ POOL ERROR"
)
persist_print(f"❌ POOL ERROR - Seed {seeds[i]}: {str(e)}")
results.append((seeds[i], False, f"Pool error: {str(e)}", 0.0, -1))
results.append(
FuzzerResult(seeds[i], False, f"Pool error: {str(e)}", 0.0, -1)
)
# Close progress bar
if HAS_TQDM and pbar:
@ -303,10 +316,12 @@ def run_multi_process_fuzzer(
persist_print("📈 SUMMARY (partial, interrupted)")
persist_print("=" * 60)
successful = [r for r in results if r[1]]
successful = [res for res in results if res.success]
# Only count as failed if not ignored
failed = [r for r in results if not r[1] and r[4] == -1]
ignored = [r for r in results if r[4] != -1]
failed = [
res for res in results if not res.success and res.ignored_pattern_idx == -1
]
ignored = [res for res in results if res.ignored_pattern_idx != -1]
persist_print(
f"✅ Successful: {len(successful)}/{len(results)} ({(len(successful) / len(results) * 100 if results else 0):.1f}%)"
@ -322,13 +337,13 @@ def run_multi_process_fuzzer(
else "⚡ Throughput: N/A"
)
if failed:
persist_print(f"\n❌ Failed seeds: {[r[0] for r in failed]}")
persist_print(f"\n❌ Failed seeds: {[res.seed for res in failed]}")
if successful:
persist_print(f"✅ Successful seeds: {[r[0] for r in successful]}")
avg_success_time = sum(r[3] for r in successful) / len(successful)
persist_print(f"✅ Successful seeds: {[res.seed for res in successful]}")
avg_success_time = sum(res.duration for res in successful) / len(successful)
persist_print(f"⚡ Avg time for successful runs: {avg_success_time:.2f}s")
if ignored:
persist_print(f"\n🚫 Ignored seeds: {[r[0] for r in ignored]}")
persist_print(f"\n🚫 Ignored seeds: {[res.seed for res in ignored]}")
# Print ignore pattern stats
persist_print("\n🚫 Ignored pattern statistics:")
total_ignored = len(ignored)
@ -348,10 +363,12 @@ def run_multi_process_fuzzer(
persist_print("📈 SUMMARY")
persist_print("=" * 60)
successful = [r for r in results if r[1]]
successful = [res for res in results if res.success]
# Only count as failed if not ignored
failed = [r for r in results if not r[1] and r[4] == -1]
ignored = [r for r in results if r[4] != -1]
failed = [
res for res in results if not res.success and res.ignored_pattern_idx == -1
]
ignored = [res for res in results if res.ignored_pattern_idx != -1]
persist_print(
f"✅ Successful: {len(successful)}/{len(results)} ({len(successful) / len(results) * 100:.1f}%)"
@ -367,15 +384,15 @@ def run_multi_process_fuzzer(
)
if failed:
persist_print(f"\n❌ Failed seeds: {[r[0] for r in failed]}")
persist_print(f"\n❌ Failed seeds: {[res.seed for res in failed]}")
if successful:
persist_print(f"✅ Successful seeds: {[r[0] for r in successful]}")
avg_success_time = sum(r[3] for r in successful) / len(successful)
persist_print(f"✅ Successful seeds: {[res.seed for res in successful]}")
avg_success_time = sum(res.duration for res in successful) / len(successful)
persist_print(f"⚡ Avg time for successful runs: {avg_success_time:.2f}s")
if ignored:
persist_print(f"\n🚫 Ignored seeds: {[r[0] for r in ignored]}")
persist_print(f"\n🚫 Ignored seeds: {[res.seed for res in ignored]}")
# Print ignore pattern stats
persist_print("\n🚫 Ignored pattern statistics:")
total_ignored = len(ignored)