mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchfuzz] refactor multi_process_fuzzer to be more readable (#163698)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163698 Approved by: https://github.com/pianpwk ghstack dependencies: #163547, #163553, #163554, #163555, #163556, #163557, #163558, #163560
This commit is contained in:
committed by
PyTorch MergeBot
parent
754c7e2e88
commit
d927e55498
@ -8,6 +8,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
try:
|
||||
@ -53,6 +54,15 @@ IGNORE_PATTERNS: list[re.Pattern] = [
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuzzerResult:
|
||||
seed: int
|
||||
success: bool
|
||||
output: str
|
||||
duration: float
|
||||
ignored_pattern_idx: int
|
||||
|
||||
|
||||
def is_ignored_output(output: str) -> int:
|
||||
"""
|
||||
Check if the output matches any ignore pattern.
|
||||
@ -69,7 +79,7 @@ def is_ignored_output(output: str) -> int:
|
||||
return -1
|
||||
|
||||
|
||||
def run_fuzzer_with_seed(seed: int) -> tuple[int, bool, str, float, int]:
|
||||
def run_fuzzer_with_seed(seed: int) -> FuzzerResult:
|
||||
"""
|
||||
Run fuzzer.py with a specific seed.
|
||||
|
||||
@ -77,8 +87,7 @@ def run_fuzzer_with_seed(seed: int) -> tuple[int, bool, str, float, int]:
|
||||
seed: The seed value to pass to fuzzer.py
|
||||
|
||||
Returns:
|
||||
Tuple of (seed, success, output, duration, ignored_pattern_idx)
|
||||
ignored_pattern_idx: -1 if not ignored, otherwise index of IGNORE_PATTERNS
|
||||
FuzzerResult dataclass instance
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
@ -110,15 +119,44 @@ def run_fuzzer_with_seed(seed: int) -> tuple[int, bool, str, float, int]:
|
||||
# Mark as ignored (could also return a special flag if needed)
|
||||
output = "[IGNORED] " + output
|
||||
|
||||
return seed, success, output, duration, ignored_pattern_idx
|
||||
return FuzzerResult(seed, success, output, duration, ignored_pattern_idx)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
duration = time.time() - start_time
|
||||
return seed, False, "Process timed out after 300 seconds", duration, -1
|
||||
return FuzzerResult(
|
||||
seed, False, "Process timed out after 300 seconds", duration, -1
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
return seed, False, f"Exception occurred: {str(e)}", duration, -1
|
||||
return FuzzerResult(seed, False, f"Exception occurred: {str(e)}", duration, -1)
|
||||
|
||||
|
||||
def print_output_lines(output: str, write_func):
|
||||
"""Helper to print non-empty lines of output using the provided write_func."""
|
||||
for line in output.split("\n"):
|
||||
if line.strip():
|
||||
write_func(f" {line}")
|
||||
if hasattr(write_func, "__self__") and hasattr(write_func.__self__, "write"):
|
||||
# For tqdm.write, add an empty line for separation
|
||||
write_func("")
|
||||
|
||||
|
||||
def handle_result_output(
|
||||
*,
|
||||
label: str,
|
||||
seed: int,
|
||||
duration: float,
|
||||
output: str,
|
||||
ignored: bool,
|
||||
verbose: bool,
|
||||
write_func,
|
||||
):
|
||||
"""Unified handler for result output, reducing code repetition."""
|
||||
ignored_text = " [IGNORED]" if ignored else ""
|
||||
write_func(f"{label} - Seed {seed} (duration: {duration:.2f}s){ignored_text}")
|
||||
if output.strip() or label.startswith("❌") or verbose:
|
||||
print_output_lines(output, write_func)
|
||||
|
||||
|
||||
def run_multi_process_fuzzer(
|
||||
@ -146,7 +184,7 @@ def run_multi_process_fuzzer(
|
||||
persist_print("=" * 60)
|
||||
|
||||
start_time = time.time()
|
||||
results = []
|
||||
results: list[FuzzerResult] = []
|
||||
successful_count = 0
|
||||
failed_count = 0
|
||||
ignored_count = 0
|
||||
@ -176,27 +214,29 @@ def run_multi_process_fuzzer(
|
||||
pbar.set_postfix_str(
|
||||
f"{successful_count}/{failed_count}/{ignored_count} | throughput: 0.00 seeds/hr"
|
||||
)
|
||||
|
||||
def write_func(msg):
|
||||
pbar.write(msg)
|
||||
else:
|
||||
persist_print("Progress: (install tqdm for better progress bar)")
|
||||
pbar = None
|
||||
write_func = persist_print
|
||||
|
||||
# Collect results as they complete
|
||||
for i, future in enumerate(future_results):
|
||||
try:
|
||||
seed, success, output, duration, ignored_pattern_idx = future.get()
|
||||
results.append(
|
||||
(seed, success, output, duration, ignored_pattern_idx)
|
||||
)
|
||||
result: FuzzerResult = future.get()
|
||||
results.append(result)
|
||||
|
||||
if ignored_pattern_idx != -1:
|
||||
ignored_seeds.append(seed)
|
||||
ignored_pattern_counts[ignored_pattern_idx] += 1
|
||||
if result.ignored_pattern_idx != -1:
|
||||
ignored_seeds.append(result.seed)
|
||||
ignored_pattern_counts[result.ignored_pattern_idx] += 1
|
||||
ignored_count += 1
|
||||
|
||||
# Only increment failed_count if not ignored
|
||||
if success:
|
||||
if result.success:
|
||||
successful_count += 1
|
||||
elif ignored_pattern_idx == -1:
|
||||
elif result.ignored_pattern_idx == -1:
|
||||
failed_count += 1
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
@ -207,77 +247,48 @@ def run_multi_process_fuzzer(
|
||||
pbar.set_postfix_str(
|
||||
f"{successful_count}/{failed_count}/{ignored_count} | throughput: {throughput:.2f} seeds/hr"
|
||||
)
|
||||
# tqdm automatically shows ETA (estimated time remaining) in the bar_format above
|
||||
pbar.update(1)
|
||||
else:
|
||||
status_emoji = "✅" if success else "❌"
|
||||
ignored_text = " (IGNORED)" if ignored_pattern_idx != -1 else ""
|
||||
status_emoji = "✅" if result.success else "❌"
|
||||
ignored_text = (
|
||||
" (IGNORED)" if result.ignored_pattern_idx != -1 else ""
|
||||
)
|
||||
persist_print(
|
||||
f"Completed {i + 1}/{len(seeds)} - Seed {seed}: {status_emoji}{ignored_text}"
|
||||
f"Completed {i + 1}/{len(seeds)} - Seed {result.seed}: {status_emoji}{ignored_text}"
|
||||
)
|
||||
|
||||
# Only show detailed output for failures (unless verbose)
|
||||
if not success and ignored_pattern_idx == -1:
|
||||
if HAS_TQDM and pbar:
|
||||
pbar.write(
|
||||
f"❌ FAILURE - Seed {seed} (duration: {duration:.2f}s):"
|
||||
)
|
||||
for line in output.split("\n"):
|
||||
if line.strip():
|
||||
pbar.write(f" {line}")
|
||||
pbar.write("") # Empty line
|
||||
else:
|
||||
persist_print(
|
||||
f"❌ FAILURE - Seed {seed} (duration: {duration:.2f}s):"
|
||||
)
|
||||
for line in output.split("\n"):
|
||||
if line.strip():
|
||||
persist_print(f" {line}")
|
||||
persist_print("")
|
||||
elif not success and ignored_pattern_idx != -1:
|
||||
# Optionally, print ignored failures if desired
|
||||
# Unified output handling
|
||||
if not result.success and result.ignored_pattern_idx == -1:
|
||||
handle_result_output(
|
||||
label="❌ FAILURE",
|
||||
seed=result.seed,
|
||||
duration=result.duration,
|
||||
output=result.output,
|
||||
ignored=False,
|
||||
verbose=verbose,
|
||||
write_func=write_func,
|
||||
)
|
||||
elif not result.success and result.ignored_pattern_idx != -1:
|
||||
if verbose:
|
||||
if HAS_TQDM and pbar:
|
||||
pbar.write(
|
||||
f"🚫 IGNORED - Seed {seed} (duration: {duration:.2f}s):"
|
||||
)
|
||||
for line in output.split("\n"):
|
||||
if line.strip():
|
||||
pbar.write(f" {line}")
|
||||
pbar.write("")
|
||||
else:
|
||||
persist_print(
|
||||
f"🚫 IGNORED - Seed {seed} (duration: {duration:.2f}s):"
|
||||
)
|
||||
for line in output.split("\n"):
|
||||
if line.strip():
|
||||
persist_print(f" {line}")
|
||||
persist_print("")
|
||||
handle_result_output(
|
||||
label="🚫 IGNORED",
|
||||
seed=result.seed,
|
||||
duration=result.duration,
|
||||
output=result.output,
|
||||
ignored=True,
|
||||
verbose=verbose,
|
||||
write_func=write_func,
|
||||
)
|
||||
elif verbose:
|
||||
if HAS_TQDM and pbar:
|
||||
ignored_text = (
|
||||
" [IGNORED]" if ignored_pattern_idx != -1 else ""
|
||||
)
|
||||
pbar.write(
|
||||
f"✅ SUCCESS - Seed {seed} (duration: {duration:.2f}s){ignored_text}"
|
||||
)
|
||||
if output.strip():
|
||||
for line in output.split("\n"):
|
||||
if line.strip():
|
||||
pbar.write(f" {line}")
|
||||
pbar.write("")
|
||||
else:
|
||||
ignored_text = (
|
||||
" [IGNORED]" if ignored_pattern_idx != -1 else ""
|
||||
)
|
||||
persist_print(
|
||||
f"✅ SUCCESS - Seed {seed} (duration: {duration:.2f}s){ignored_text}"
|
||||
)
|
||||
if output.strip():
|
||||
for line in output.split("\n"):
|
||||
if line.strip():
|
||||
persist_print(f" {line}")
|
||||
persist_print("")
|
||||
handle_result_output(
|
||||
label="✅ SUCCESS",
|
||||
seed=result.seed,
|
||||
duration=result.duration,
|
||||
output=result.output,
|
||||
ignored=(result.ignored_pattern_idx != -1),
|
||||
verbose=verbose,
|
||||
write_func=write_func,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
@ -290,7 +301,9 @@ def run_multi_process_fuzzer(
|
||||
f"Completed {i + 1}/{len(seeds)} - Seed {seeds[i]}: ❌ POOL ERROR"
|
||||
)
|
||||
persist_print(f"❌ POOL ERROR - Seed {seeds[i]}: {str(e)}")
|
||||
results.append((seeds[i], False, f"Pool error: {str(e)}", 0.0, -1))
|
||||
results.append(
|
||||
FuzzerResult(seeds[i], False, f"Pool error: {str(e)}", 0.0, -1)
|
||||
)
|
||||
|
||||
# Close progress bar
|
||||
if HAS_TQDM and pbar:
|
||||
@ -303,10 +316,12 @@ def run_multi_process_fuzzer(
|
||||
persist_print("📈 SUMMARY (partial, interrupted)")
|
||||
persist_print("=" * 60)
|
||||
|
||||
successful = [r for r in results if r[1]]
|
||||
successful = [res for res in results if res.success]
|
||||
# Only count as failed if not ignored
|
||||
failed = [r for r in results if not r[1] and r[4] == -1]
|
||||
ignored = [r for r in results if r[4] != -1]
|
||||
failed = [
|
||||
res for res in results if not res.success and res.ignored_pattern_idx == -1
|
||||
]
|
||||
ignored = [res for res in results if res.ignored_pattern_idx != -1]
|
||||
|
||||
persist_print(
|
||||
f"✅ Successful: {len(successful)}/{len(results)} ({(len(successful) / len(results) * 100 if results else 0):.1f}%)"
|
||||
@ -322,13 +337,13 @@ def run_multi_process_fuzzer(
|
||||
else "⚡ Throughput: N/A"
|
||||
)
|
||||
if failed:
|
||||
persist_print(f"\n❌ Failed seeds: {[r[0] for r in failed]}")
|
||||
persist_print(f"\n❌ Failed seeds: {[res.seed for res in failed]}")
|
||||
if successful:
|
||||
persist_print(f"✅ Successful seeds: {[r[0] for r in successful]}")
|
||||
avg_success_time = sum(r[3] for r in successful) / len(successful)
|
||||
persist_print(f"✅ Successful seeds: {[res.seed for res in successful]}")
|
||||
avg_success_time = sum(res.duration for res in successful) / len(successful)
|
||||
persist_print(f"⚡ Avg time for successful runs: {avg_success_time:.2f}s")
|
||||
if ignored:
|
||||
persist_print(f"\n🚫 Ignored seeds: {[r[0] for r in ignored]}")
|
||||
persist_print(f"\n🚫 Ignored seeds: {[res.seed for res in ignored]}")
|
||||
# Print ignore pattern stats
|
||||
persist_print("\n🚫 Ignored pattern statistics:")
|
||||
total_ignored = len(ignored)
|
||||
@ -348,10 +363,12 @@ def run_multi_process_fuzzer(
|
||||
persist_print("📈 SUMMARY")
|
||||
persist_print("=" * 60)
|
||||
|
||||
successful = [r for r in results if r[1]]
|
||||
successful = [res for res in results if res.success]
|
||||
# Only count as failed if not ignored
|
||||
failed = [r for r in results if not r[1] and r[4] == -1]
|
||||
ignored = [r for r in results if r[4] != -1]
|
||||
failed = [
|
||||
res for res in results if not res.success and res.ignored_pattern_idx == -1
|
||||
]
|
||||
ignored = [res for res in results if res.ignored_pattern_idx != -1]
|
||||
|
||||
persist_print(
|
||||
f"✅ Successful: {len(successful)}/{len(results)} ({len(successful) / len(results) * 100:.1f}%)"
|
||||
@ -367,15 +384,15 @@ def run_multi_process_fuzzer(
|
||||
)
|
||||
|
||||
if failed:
|
||||
persist_print(f"\n❌ Failed seeds: {[r[0] for r in failed]}")
|
||||
persist_print(f"\n❌ Failed seeds: {[res.seed for res in failed]}")
|
||||
|
||||
if successful:
|
||||
persist_print(f"✅ Successful seeds: {[r[0] for r in successful]}")
|
||||
avg_success_time = sum(r[3] for r in successful) / len(successful)
|
||||
persist_print(f"✅ Successful seeds: {[res.seed for res in successful]}")
|
||||
avg_success_time = sum(res.duration for res in successful) / len(successful)
|
||||
persist_print(f"⚡ Avg time for successful runs: {avg_success_time:.2f}s")
|
||||
|
||||
if ignored:
|
||||
persist_print(f"\n🚫 Ignored seeds: {[r[0] for r in ignored]}")
|
||||
persist_print(f"\n🚫 Ignored seeds: {[res.seed for res in ignored]}")
|
||||
# Print ignore pattern stats
|
||||
persist_print("\n🚫 Ignored pattern statistics:")
|
||||
total_ignored = len(ignored)
|
||||
|
Reference in New Issue
Block a user