[torchfuzz] add support for --stop-at-first-failure flag (#165529)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165529
Approved by: https://github.com/pianpwk
ghstack dependencies: #164749
This commit is contained in:
bobrenjc93
2025-10-16 14:11:01 -07:00
committed by PyTorch MergeBot
parent d0add0be43
commit 7dabfb07cb
3 changed files with 173 additions and 4 deletions

View File

@ -196,7 +196,7 @@ class FuzzTemplate:
class DefaultFuzzTemplate(FuzzTemplate):
def __init__(self):
from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck
from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck
super().__init__(
supported_ops=[
@ -236,7 +236,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
# Regularization
"torch.nn.functional.dropout",
],
check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
check=EagerVsFullGraphDynamicCompileCheck(),
)
def spec_distribution(self):

View File

@ -241,7 +241,7 @@ if __name__ == "__main__":
import argparse
try:
from multi_process_fuzzer import run_multi_process_fuzzer
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
except ImportError:
# If importing as a module fails, import from the same directory
import os
@ -249,7 +249,7 @@ if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
from multi_process_fuzzer import run_multi_process_fuzzer
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
# Set up command-line argument parsing
parser = argparse.ArgumentParser(
@ -296,6 +296,11 @@ if __name__ == "__main__":
action="store_true",
help="Print detailed output for all runs (not just failures)",
)
parser.add_argument(
"--stop-at-first-failure",
action="store_true",
help="Pick a random seed and keep iterating until finding a failure (exits with non-zero code)",
)
# Legacy arguments
parser.add_argument(
@ -337,6 +342,30 @@ if __name__ == "__main__":
supported_ops=parsed_supported_ops,
op_weights=(parsed_weights if parsed_weights else None),
)
elif args.stop_at_first_failure:
# Stop-at-first-failure mode
# Default number of processes
if args.processes is None:
cpu_count = mp.cpu_count()
args.processes = max(1, min(16, int(cpu_count * 0.75)))
if args.processes < 1:
print("❌ Error: Number of processes must be at least 1")
sys.exit(1)
try:
run_until_failure(
num_processes=args.processes,
verbose=args.verbose,
template=args.template,
supported_ops=args.supported_ops,
)
except Exception as e:
print(f"❌ Unexpected error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
elif args.start is not None or args.count is not None:
# Multi-process fuzzing mode
if args.start is None:

View File

@ -522,3 +522,143 @@ def _print_operation_distribution(results: list[FuzzerResult]) -> None:
persist_print(
"\n📊 No operation statistics collected (no successful runs with stats)"
)
def _report_first_failure(
    result: "FuzzerResult",
    template: str,
    elapsed: float,
    total_successful: int,
    total_ignored: int,
    total_tested: int,
) -> None:
    """Print the report for the first failing seed that is not ignored."""
    persist_print("\n" + "=" * 60)
    persist_print("🎯 FAILURE FOUND!")
    persist_print("=" * 60)
    persist_print(f"❌ Failing seed: {result.seed}")
    persist_print(f"⏱️ Duration for this seed: {result.duration:.2f}s")
    persist_print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
    persist_print(f"✅ Successful seeds tested: {total_successful}")
    persist_print(f"🚫 Ignored seeds: {total_ignored}")
    persist_print(f"📊 Total seeds tested: {total_tested}")
    persist_print("\n💥 Failure output:")
    persist_print("-" * 60)
    print_output_lines(result.output, persist_print)
    persist_print("-" * 60)
    persist_print(
        f"\n🔄 Reproduce with: python fuzzer.py --seed {result.seed} --template {template}"
    )


def _report_interrupted_run(
    elapsed: float,
    total_successful: int,
    total_ignored: int,
    total_tested: int,
) -> None:
    """Print the summary shown when the run is stopped with Ctrl+C."""
    persist_print("=" * 60)
    persist_print("📈 SUMMARY (interrupted)")
    persist_print("=" * 60)
    persist_print(f"⏱️ Total time: {elapsed:.2f}s")
    persist_print(f"✅ Successful seeds: {total_successful}")
    persist_print(f"🚫 Ignored seeds: {total_ignored}")
    persist_print(f"📊 Total seeds tested: {total_tested}")
    # Guard against a zero elapsed time (interrupt before the clock advanced).
    seeds_per_hour = total_tested / (elapsed / 3600) if elapsed > 0 else 0.0
    persist_print(f"⚡ Throughput: {seeds_per_hour:.2f} seeds/hr")


def run_until_failure(
    num_processes: Optional[int] = None,
    verbose: bool = False,
    template: str = "default",
    supported_ops: Optional[str] = None,
) -> None:
    """
    Run the multi-process fuzzer from a random seed until a failure is found.

    Seeds are tested in consecutive batches of 100, starting at a randomly
    chosen seed. Each batch is dispatched to a fresh ``mp.Pool`` and results
    are consumed in seed order. The first result that is neither successful
    nor matched by an ignore pattern (``ignored_pattern_idx == -1``) is
    reported and terminates the process.

    Args:
        num_processes: Number of worker processes to use. ``None`` is passed
            straight to ``mp.Pool``, which then picks its own default.
        verbose: Print a per-seed status line. Only takes effect when tqdm is
            not available; otherwise the progress bar is shown instead.
        template: The template to use for code generation.
        supported_ops: Comma-separated ops string with optional weights.

    Raises:
        SystemExit: With code 1 once a non-ignored failure is found, or with
            code 130 when interrupted by Ctrl+C. The function never returns
            normally.
    """
    import random

    # Pick a random seed to start from
    initial_seed = random.randint(0, 2**31 - 1)

    persist_print(
        f"🎲 Starting continuous fuzzing with random initial seed: {initial_seed}"
    )
    persist_print(f"🚀 Using {num_processes} processes")
    persist_print(
        f"🔧 Command template: python fuzzer.py --seed {{seed}} --template {template}"
    )
    persist_print("🎯 Running until first failure is found...")
    persist_print("=" * 60)

    start_time = time.time()
    current_seed = initial_seed
    total_successful = 0
    total_ignored = 0
    # Counted explicitly rather than derived from the two counters above, so
    # the total stays right even if a result is both successful and ignored.
    total_tested = 0
    batch_size = 100  # Process seeds in batches of 100

    try:
        while True:
            seeds = range(current_seed, current_seed + batch_size)

            with mp.Pool(processes=num_processes) as pool:
                pending = [
                    (
                        seed,
                        pool.apply_async(
                            run_fuzzer_with_seed, (seed, template, supported_ops)
                        ),
                    )
                    for seed in seeds
                ]

                # Set up progress bar for this batch
                pbar = None
                if HAS_TQDM:
                    from tqdm import tqdm

                    pbar = tqdm(
                        total=len(seeds),
                        desc=f"Batch starting at seed {current_seed}",
                        file=sys.stdout,
                        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}] ✅/🚫={postfix}",
                        dynamic_ncols=True,
                    )
                    pbar.set_postfix_str(f"{total_successful}/{total_ignored}")

                try:
                    # Collect results in submission (seed) order
                    for _seed, future in pending:
                        result: FuzzerResult = future.get()
                        total_tested += 1

                        is_ignored = result.ignored_pattern_idx != -1
                        if is_ignored:
                            total_ignored += 1

                        if result.success:
                            total_successful += 1
                        elif not is_ignored:
                            # Found a failure that is not ignored! Close the
                            # bar first so it does not interleave with the
                            # report.
                            if pbar is not None:
                                pbar.close()
                                pbar = None
                            _report_first_failure(
                                result,
                                template,
                                time.time() - start_time,
                                total_successful,
                                total_ignored,
                                total_tested,
                            )
                            # Exit with non-zero code; leaving the ``with``
                            # block tears down the remaining workers.
                            sys.exit(1)

                        # Update progress bar, or fall back to plain lines
                        if pbar is not None:
                            pbar.set_postfix_str(
                                f"{total_successful}/{total_ignored}"
                            )
                            pbar.update(1)
                        elif verbose:
                            status_emoji = "✅" if result.success else "🚫"
                            persist_print(f"Seed {result.seed}: {status_emoji}")
                finally:
                    # Also runs on Ctrl+C or a worker exception, so the bar
                    # never lingers over the output that follows.
                    if pbar is not None:
                        pbar.close()

            # Move to next batch
            current_seed += batch_size

    except KeyboardInterrupt:
        persist_print("\n🛑 Interrupted by user (Ctrl+C)")
        _report_interrupted_run(
            time.time() - start_time,
            total_successful,
            total_ignored,
            total_tested,
        )
        sys.exit(130)