mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[torchfuzz] add support for --stop-at-first-failure flag (#165529)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165529 Approved by: https://github.com/pianpwk ghstack dependencies: #164749
This commit is contained in:
committed by
PyTorch MergeBot
parent
d0add0be43
commit
7dabfb07cb
@ -196,7 +196,7 @@ class FuzzTemplate:
|
||||
|
||||
class DefaultFuzzTemplate(FuzzTemplate):
|
||||
def __init__(self):
|
||||
from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck
|
||||
from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck
|
||||
|
||||
super().__init__(
|
||||
supported_ops=[
|
||||
@ -236,7 +236,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
|
||||
# Regularization
|
||||
"torch.nn.functional.dropout",
|
||||
],
|
||||
check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
|
||||
check=EagerVsFullGraphDynamicCompileCheck(),
|
||||
)
|
||||
|
||||
def spec_distribution(self):
|
||||
|
@ -241,7 +241,7 @@ if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
try:
|
||||
from multi_process_fuzzer import run_multi_process_fuzzer
|
||||
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
|
||||
except ImportError:
|
||||
# If importing as a module fails, import from the same directory
|
||||
import os
|
||||
@ -249,7 +249,7 @@ if __name__ == "__main__":
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, current_dir)
|
||||
from multi_process_fuzzer import run_multi_process_fuzzer
|
||||
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
|
||||
|
||||
# Set up command-line argument parsing
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -296,6 +296,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Print detailed output for all runs (not just failures)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop-at-first-failure",
|
||||
action="store_true",
|
||||
help="Pick a random seed and keep iterating until finding a failure (exits with non-zero code)",
|
||||
)
|
||||
|
||||
# Legacy arguments
|
||||
parser.add_argument(
|
||||
@ -337,6 +342,30 @@ if __name__ == "__main__":
|
||||
supported_ops=parsed_supported_ops,
|
||||
op_weights=(parsed_weights if parsed_weights else None),
|
||||
)
|
||||
elif args.stop_at_first_failure:
|
||||
# Stop-at-first-failure mode
|
||||
# Default number of processes
|
||||
if args.processes is None:
|
||||
cpu_count = mp.cpu_count()
|
||||
args.processes = max(1, min(16, int(cpu_count * 0.75)))
|
||||
|
||||
if args.processes < 1:
|
||||
print("❌ Error: Number of processes must be at least 1")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
run_until_failure(
|
||||
num_processes=args.processes,
|
||||
verbose=args.verbose,
|
||||
template=args.template,
|
||||
supported_ops=args.supported_ops,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"❌ Unexpected error: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
elif args.start is not None or args.count is not None:
|
||||
# Multi-process fuzzing mode
|
||||
if args.start is None:
|
||||
|
@ -522,3 +522,143 @@ def _print_operation_distribution(results: list[FuzzerResult]) -> None:
|
||||
persist_print(
|
||||
"\n📊 No operation statistics collected (no successful runs with stats)"
|
||||
)
|
||||
|
||||
|
||||
def _report_first_failure(
    result: "FuzzerResult",
    template: str,
    elapsed: float,
    total_successful: int,
    total_ignored: int,
) -> None:
    """Print the failure banner, counters, captured output and repro command."""
    persist_print("\n" + "=" * 60)
    persist_print("🎯 FAILURE FOUND!")
    persist_print("=" * 60)
    persist_print(f"❌ Failing seed: {result.seed}")
    persist_print(f"⏱️ Duration for this seed: {result.duration:.2f}s")
    persist_print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
    persist_print(f"✅ Successful seeds tested: {total_successful}")
    persist_print(f"🚫 Ignored seeds: {total_ignored}")
    # "+ 1" accounts for the failing seed itself, which is in neither counter.
    persist_print(f"📊 Total seeds tested: {total_successful + total_ignored + 1}")
    persist_print("\n💥 Failure output:")
    persist_print("-" * 60)
    print_output_lines(result.output, persist_print)
    persist_print("-" * 60)
    # NOTE(review): the repro command does not include the supported-ops
    # selection; if a custom op list was used the command below may not
    # regenerate the same program -- confirm against fuzzer.py's CLI.
    persist_print(
        f"\n🔄 Reproduce with: python fuzzer.py --seed {result.seed} --template {template}"
    )


def _report_interrupt(
    elapsed: float, total_successful: int, total_ignored: int
) -> None:
    """Print the summary shown when the run is interrupted with Ctrl+C."""
    tested = total_successful + total_ignored
    # Guard against a zero elapsed time (coarse clocks / immediate Ctrl+C),
    # which would otherwise raise ZeroDivisionError inside the handler.
    seeds_per_hour = tested / (elapsed / 3600) if elapsed > 0 else 0.0

    persist_print("=" * 60)
    persist_print("📈 SUMMARY (interrupted)")
    persist_print("=" * 60)
    persist_print(f"⏱️  Total time: {elapsed:.2f}s")
    persist_print(f"✅ Successful seeds: {total_successful}")
    persist_print(f"🚫 Ignored seeds: {total_ignored}")
    persist_print(f"📊 Total seeds tested: {tested}")
    persist_print(f"⚡ Throughput: {seeds_per_hour:.2f} seeds/hr")


def run_until_failure(
    num_processes: Optional[int] = None,
    verbose: bool = False,
    template: str = "default",
    supported_ops: Optional[str] = None,
) -> None:
    """
    Run the multi-process fuzzer from a random starting seed until a failure is found.

    Seeds are processed in consecutive batches of 100. Each batch is fanned out
    to a fresh ``mp.Pool`` via ``run_fuzzer_with_seed`` and the results are
    consumed in seed order. The first result that is neither successful nor
    matched by an ignore pattern (``ignored_pattern_idx == -1``) is reported
    and the process exits.

    Args:
        num_processes: Number of worker processes to use. When ``None``, the
            same default as the command-line entry point is used
            (75% of the CPU count, clamped to the range 1..16).
        verbose: Whether to print a per-seed status line. Only takes effect
            when no tqdm progress bar is available.
        template: The template to use for code generation
        supported_ops: Comma-separated ops string with optional weights

    Returns:
        Never returns normally: exits with status 1 when a failure is found,
        or with status 130 when interrupted by the user (Ctrl+C).
    """
    import random

    if num_processes is None:
        # Mirror the CLI default so the banner below never prints "None".
        num_processes = max(1, min(16, int(mp.cpu_count() * 0.75)))

    # Pick a random seed to start from
    initial_seed = random.randint(0, 2**31 - 1)

    persist_print(
        f"🎲 Starting continuous fuzzing with random initial seed: {initial_seed}"
    )
    persist_print(f"🚀 Using {num_processes} processes")
    persist_print(
        f"🔧 Command template: python fuzzer.py --seed {{seed}} --template {template}"
    )
    persist_print("🎯 Running until first failure is found...")
    persist_print("=" * 60)

    start_time = time.time()
    current_seed = initial_seed
    total_successful = 0
    total_ignored = 0
    batch_size = 100  # Process seeds in batches of 100

    if HAS_TQDM:
        # Imported once up front instead of on every batch.
        from tqdm import tqdm

    try:
        while True:
            # Process a batch of seeds
            seeds = range(current_seed, current_seed + batch_size)

            with mp.Pool(processes=num_processes) as pool:
                future_results = [
                    (
                        seed,
                        pool.apply_async(
                            run_fuzzer_with_seed, (seed, template, supported_ops)
                        ),
                    )
                    for seed in seeds
                ]

                # Set up progress bar for this batch
                pbar = None
                if HAS_TQDM:
                    pbar = tqdm(
                        total=len(seeds),
                        desc=f"Batch starting at seed {current_seed}",
                        file=sys.stdout,
                        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}] ✅/🚫={postfix}",
                        dynamic_ncols=True,
                    )
                    pbar.set_postfix_str(f"{total_successful}/{total_ignored}")

                try:
                    # Collect results in submission (seed) order
                    for _seed, future in future_results:
                        result: FuzzerResult = future.get()

                        if result.ignored_pattern_idx != -1:
                            total_ignored += 1

                        if result.success:
                            total_successful += 1
                        elif result.ignored_pattern_idx == -1:
                            # Found a failure that is not ignored!
                            if pbar is not None:
                                # Close first so the bar does not interleave
                                # with the report printed below.
                                pbar.close()

                            _report_first_failure(
                                result,
                                template,
                                time.time() - start_time,
                                total_successful,
                                total_ignored,
                            )

                            # Exit with non-zero code; leaving the ``with``
                            # block tears down the remaining workers.
                            sys.exit(1)

                        # Update progress bar
                        if pbar is not None:
                            pbar.set_postfix_str(
                                f"{total_successful}/{total_ignored}"
                            )
                            pbar.update(1)
                        elif verbose:
                            status_emoji = "✅" if result.success else "🚫"
                            persist_print(f"Seed {result.seed}: {status_emoji}")
                finally:
                    # Always release the bar for this batch, including when a
                    # worker raises or the user interrupts; closing an
                    # already-closed bar is expected to be a no-op.
                    if pbar is not None:
                        pbar.close()

            # Move to next batch
            current_seed += batch_size

    except KeyboardInterrupt:
        persist_print("\n🛑 Interrupted by user (Ctrl+C)")
        _report_interrupt(time.time() - start_time, total_successful, total_ignored)
        sys.exit(130)
|
||||
|
Reference in New Issue
Block a user