mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchfuzz] add support for --stop-at-first-failure flag (#165529)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165529 Approved by: https://github.com/pianpwk ghstack dependencies: #164749
This commit is contained in:
committed by
PyTorch MergeBot
parent
d0add0be43
commit
7dabfb07cb
@ -196,7 +196,7 @@ class FuzzTemplate:
|
|||||||
|
|
||||||
class DefaultFuzzTemplate(FuzzTemplate):
|
class DefaultFuzzTemplate(FuzzTemplate):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck
|
from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
supported_ops=[
|
supported_ops=[
|
||||||
@ -236,7 +236,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
|
|||||||
# Regularization
|
# Regularization
|
||||||
"torch.nn.functional.dropout",
|
"torch.nn.functional.dropout",
|
||||||
],
|
],
|
||||||
check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
|
check=EagerVsFullGraphDynamicCompileCheck(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def spec_distribution(self):
|
def spec_distribution(self):
|
||||||
|
@ -241,7 +241,7 @@ if __name__ == "__main__":
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from multi_process_fuzzer import run_multi_process_fuzzer
|
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# If importing as a module fails, import from the same directory
|
# If importing as a module fails, import from the same directory
|
||||||
import os
|
import os
|
||||||
@ -249,7 +249,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.insert(0, current_dir)
|
sys.path.insert(0, current_dir)
|
||||||
from multi_process_fuzzer import run_multi_process_fuzzer
|
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
|
||||||
|
|
||||||
# Set up command-line argument parsing
|
# Set up command-line argument parsing
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -296,6 +296,11 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Print detailed output for all runs (not just failures)",
|
help="Print detailed output for all runs (not just failures)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stop-at-first-failure",
|
||||||
|
action="store_true",
|
||||||
|
help="Pick a random seed and keep iterating until finding a failure (exits with non-zero code)",
|
||||||
|
)
|
||||||
|
|
||||||
# Legacy arguments
|
# Legacy arguments
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -337,6 +342,30 @@ if __name__ == "__main__":
|
|||||||
supported_ops=parsed_supported_ops,
|
supported_ops=parsed_supported_ops,
|
||||||
op_weights=(parsed_weights if parsed_weights else None),
|
op_weights=(parsed_weights if parsed_weights else None),
|
||||||
)
|
)
|
||||||
|
elif args.stop_at_first_failure:
|
||||||
|
# Stop-at-first-failure mode
|
||||||
|
# Default number of processes
|
||||||
|
if args.processes is None:
|
||||||
|
cpu_count = mp.cpu_count()
|
||||||
|
args.processes = max(1, min(16, int(cpu_count * 0.75)))
|
||||||
|
|
||||||
|
if args.processes < 1:
|
||||||
|
print("❌ Error: Number of processes must be at least 1")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
run_until_failure(
|
||||||
|
num_processes=args.processes,
|
||||||
|
verbose=args.verbose,
|
||||||
|
template=args.template,
|
||||||
|
supported_ops=args.supported_ops,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Unexpected error: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
elif args.start is not None or args.count is not None:
|
elif args.start is not None or args.count is not None:
|
||||||
# Multi-process fuzzing mode
|
# Multi-process fuzzing mode
|
||||||
if args.start is None:
|
if args.start is None:
|
||||||
|
@ -522,3 +522,143 @@ def _print_operation_distribution(results: list[FuzzerResult]) -> None:
|
|||||||
persist_print(
|
persist_print(
|
||||||
"\n📊 No operation statistics collected (no successful runs with stats)"
|
"\n📊 No operation statistics collected (no successful runs with stats)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_until_failure(
|
||||||
|
num_processes: Optional[int] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
template: str = "default",
|
||||||
|
supported_ops: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Run the multi-process fuzzer with a random starting seed, iterating until a failure is found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_processes: Number of worker processes to use
|
||||||
|
verbose: Whether to print detailed output
|
||||||
|
template: The template to use for code generation
|
||||||
|
supported_ops: Comma-separated ops string with optional weights
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exits with non-zero code when a failure is found
|
||||||
|
"""
|
||||||
|
import random
|
||||||
|
|
||||||
|
# Pick a random seed to start from
|
||||||
|
initial_seed = random.randint(0, 2**31 - 1)
|
||||||
|
|
||||||
|
persist_print(
|
||||||
|
f"🎲 Starting continuous fuzzing with random initial seed: {initial_seed}"
|
||||||
|
)
|
||||||
|
persist_print(f"🚀 Using {num_processes} processes")
|
||||||
|
persist_print(
|
||||||
|
f"🔧 Command template: python fuzzer.py --seed {{seed}} --template {template}"
|
||||||
|
)
|
||||||
|
persist_print("🎯 Running until first failure is found...")
|
||||||
|
persist_print("=" * 60)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
current_seed = initial_seed
|
||||||
|
total_successful = 0
|
||||||
|
total_ignored = 0
|
||||||
|
batch_size = 100 # Process seeds in batches of 100
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Process a batch of seeds
|
||||||
|
seeds = list(range(current_seed, current_seed + batch_size))
|
||||||
|
|
||||||
|
with mp.Pool(processes=num_processes) as pool:
|
||||||
|
future_results = []
|
||||||
|
for seed in seeds:
|
||||||
|
future = pool.apply_async(
|
||||||
|
run_fuzzer_with_seed, (seed, template, supported_ops)
|
||||||
|
)
|
||||||
|
future_results.append((seed, future))
|
||||||
|
|
||||||
|
# Set up progress bar for this batch
|
||||||
|
if HAS_TQDM:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
pbar = tqdm(
|
||||||
|
total=len(seeds),
|
||||||
|
desc=f"Batch starting at seed {current_seed}",
|
||||||
|
file=sys.stdout,
|
||||||
|
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}] ✅/🚫={postfix}",
|
||||||
|
dynamic_ncols=True,
|
||||||
|
)
|
||||||
|
pbar.set_postfix_str(f"{total_successful}/{total_ignored}")
|
||||||
|
|
||||||
|
def write_func(msg):
|
||||||
|
pbar.write(msg)
|
||||||
|
else:
|
||||||
|
pbar = None
|
||||||
|
|
||||||
|
# Collect results as they complete
|
||||||
|
for seed, future in future_results:
|
||||||
|
result: FuzzerResult = future.get()
|
||||||
|
|
||||||
|
if result.ignored_pattern_idx != -1:
|
||||||
|
total_ignored += 1
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
total_successful += 1
|
||||||
|
elif result.ignored_pattern_idx == -1:
|
||||||
|
# Found a failure that is not ignored!
|
||||||
|
if HAS_TQDM and pbar:
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
persist_print("\n" + "=" * 60)
|
||||||
|
persist_print("🎯 FAILURE FOUND!")
|
||||||
|
persist_print("=" * 60)
|
||||||
|
persist_print(f"❌ Failing seed: {result.seed}")
|
||||||
|
persist_print(
|
||||||
|
f"⏱️ Duration for this seed: {result.duration:.2f}s"
|
||||||
|
)
|
||||||
|
persist_print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
|
||||||
|
persist_print(f"✅ Successful seeds tested: {total_successful}")
|
||||||
|
persist_print(f"🚫 Ignored seeds: {total_ignored}")
|
||||||
|
persist_print(
|
||||||
|
f"📊 Total seeds tested: {total_successful + total_ignored + 1}"
|
||||||
|
)
|
||||||
|
persist_print("\n💥 Failure output:")
|
||||||
|
persist_print("-" * 60)
|
||||||
|
print_output_lines(result.output, persist_print)
|
||||||
|
persist_print("-" * 60)
|
||||||
|
persist_print(
|
||||||
|
f"\n🔄 Reproduce with: python fuzzer.py --seed {result.seed} --template {template}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exit with non-zero code
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
if HAS_TQDM and pbar:
|
||||||
|
pbar.set_postfix_str(f"{total_successful}/{total_ignored}")
|
||||||
|
pbar.update(1)
|
||||||
|
elif verbose:
|
||||||
|
status_emoji = "✅" if result.success else "🚫"
|
||||||
|
persist_print(f"Seed {result.seed}: {status_emoji}")
|
||||||
|
|
||||||
|
# Close progress bar for this batch
|
||||||
|
if HAS_TQDM and pbar:
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
# Move to next batch
|
||||||
|
current_seed += batch_size
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
persist_print("\n🛑 Interrupted by user (Ctrl+C)")
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
persist_print("=" * 60)
|
||||||
|
persist_print("📈 SUMMARY (interrupted)")
|
||||||
|
persist_print("=" * 60)
|
||||||
|
persist_print(f"⏱️ Total time: {elapsed:.2f}s")
|
||||||
|
persist_print(f"✅ Successful seeds: {total_successful}")
|
||||||
|
persist_print(f"🚫 Ignored seeds: {total_ignored}")
|
||||||
|
persist_print(f"📊 Total seeds tested: {total_successful + total_ignored}")
|
||||||
|
persist_print(
|
||||||
|
f"⚡ Throughput: {((total_successful + total_ignored) / (elapsed / 3600)):.2f} seeds/hr"
|
||||||
|
)
|
||||||
|
sys.exit(130)
|
||||||
|
Reference in New Issue
Block a user