[torchfuzz] make fuzzer deterministic (#164397)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164397
Approved by: https://github.com/pianpwk
ghstack dependencies: #164034, #164209, #164211, #164210
This commit is contained in:
bobrenjc93
2025-10-01 15:59:15 -07:00
committed by PyTorch MergeBot
parent 5dbae1eae2
commit 144378615a
5 changed files with 214 additions and 62 deletions

View File

@ -648,13 +648,13 @@ def create_program_file(python_code: str) -> str:
Returns:
Path to the created temporary file
"""
import random
import hashlib
# Generate a random nonce for the filename
nonce = random.randint(0, 1_000_000_000)
# Generate a deterministic filename based on code content hash
code_hash = hashlib.md5(python_code.encode()).hexdigest()[:8] # noqa: S324
tmp_dir = "/tmp/torchfuzz"
os.makedirs(tmp_dir, exist_ok=True)
generated_file_path = os.path.join(tmp_dir, f"fuzz_{nonce}.py")
generated_file_path = os.path.join(tmp_dir, f"fuzz_{code_hash}.py")
# Write the generated code to the specified file
with open(generated_file_path, "w") as f:

View File

@ -41,8 +41,10 @@ class ConstantOperator(Operator):
) -> str:
"""Generate code for constant creation."""
# Create constant by calling fuzzing functions during codegen with deterministic seed
# Use a deterministic seed based on the variable name to ensure reproducibility
var_seed = hash(output_name) % (2**31)
# Use a deterministic hash based on the variable name to ensure reproducibility across processes
import hashlib
var_seed = int(hashlib.md5(output_name.encode()).hexdigest()[:8], 16) % (2**31) # noqa: S324
if isinstance(output_spec, ScalarSpec):
# Call fuzz_scalar during codegen and embed the result

View File

@ -377,8 +377,11 @@ def fuzz_operation_graph(
random.seed(seed)
torch.manual_seed(seed)
# Reset global arg counter for deterministic behavior
global _next_arg_id
_next_arg_id = 0
# Global counter for unique node IDs
# Global counter for unique node IDs - start from 0 for deterministic behavior
node_counter = 0
# Dictionary to store all nodes: node_id -> OperationNode

View File

@ -357,55 +357,69 @@ def fuzz_tensor(
if seed is None:
seed = random.randint(0, 2**32 - 1)
# Set the random seed for reproducibility
# Create a local Random instance to avoid interfering with global state
local_random = random.Random(seed)
# Set the torch random seed for reproducibility
# Save and restore global torch state to avoid side effects
torch_state = torch.get_rng_state()
torch.manual_seed(seed)
random.seed(seed)
# Generate random values if not provided
if size is None:
size = fuzz_tensor_size()
# Generate random values if not provided using local random instance
old_random_state = random.getstate()
try:
# Temporarily use local random instance for deterministic generation
random.setstate(local_random.getstate())
if dtype is None:
dtype = fuzz_torch_tensor_type("default")
if size is None:
size = fuzz_tensor_size()
if stride is None:
stride = fuzz_valid_stride(size)
if dtype is None:
dtype = fuzz_torch_tensor_type("default")
# Handle empty tensor case
if len(size) == 0:
return torch.ones((), dtype=dtype), seed
if stride is None:
stride = fuzz_valid_stride(size)
# Calculate required storage size for the custom stride
required_storage = _compute_storage_size_needed(size, stride)
# Handle empty tensor case
if len(size) == 0:
return torch.ones((), dtype=dtype), seed
# Create base tensor with sufficient storage
if FuzzerConfig.use_real_values:
# Use random values based on dtype
if dtype.is_floating_point:
base_tensor = torch.randn(required_storage, dtype=dtype)
elif dtype in [torch.complex64, torch.complex128]:
# Create complex tensor with random real and imaginary parts
real_part = torch.randn(
required_storage,
dtype=torch.float32 if dtype == torch.complex64 else torch.float64,
)
imag_part = torch.randn(
required_storage,
dtype=torch.float32 if dtype == torch.complex64 else torch.float64,
)
base_tensor = torch.complex(real_part, imag_part).to(dtype)
elif dtype == torch.bool:
base_tensor = torch.randint(0, 2, (required_storage,), dtype=torch.bool)
else: # integer types
base_tensor = torch.randint(-100, 100, (required_storage,), dtype=dtype)
else:
# Use ones (default behavior)
base_tensor = torch.ones(required_storage, dtype=dtype)
# Calculate required storage size for the custom stride
required_storage = _compute_storage_size_needed(size, stride)
# Create strided tensor view
strided_tensor = torch.as_strided(base_tensor, size, stride)
# Create base tensor with sufficient storage
if FuzzerConfig.use_real_values:
# Use random values based on dtype
if dtype.is_floating_point:
base_tensor = torch.randn(required_storage, dtype=dtype)
elif dtype in [torch.complex64, torch.complex128]:
# Create complex tensor with random real and imaginary parts
real_part = torch.randn(
required_storage,
dtype=torch.float32 if dtype == torch.complex64 else torch.float64,
)
imag_part = torch.randn(
required_storage,
dtype=torch.float32 if dtype == torch.complex64 else torch.float64,
)
base_tensor = torch.complex(real_part, imag_part).to(dtype)
elif dtype == torch.bool:
base_tensor = torch.randint(0, 2, (required_storage,), dtype=torch.bool)
else: # integer types
base_tensor = torch.randint(-100, 100, (required_storage,), dtype=dtype)
else:
# Use ones (default behavior)
base_tensor = torch.ones(required_storage, dtype=dtype)
return strided_tensor, seed
# Create strided tensor view
strided_tensor = torch.as_strided(base_tensor, size, stride)
return strided_tensor, seed
finally:
# Restore original random state
random.setstate(old_random_state)
# Restore original torch state
torch.set_rng_state(torch_state)
def fuzz_tensor_simple(
@ -493,23 +507,49 @@ def fuzz_scalar(spec, seed: Optional[int] = None) -> Union[float, int, bool, com
if spec.constant is not None:
return spec.constant
# Set seed for reproducibility if provided
# Create a local random instance to avoid interfering with global state
if seed is not None:
random.seed(seed)
local_random = random.Random(seed)
# Save and restore global random state
old_random_state = random.getstate()
try:
random.setstate(local_random.getstate())
# Create a scalar value based on dtype
if spec.dtype.is_floating_point:
return random.uniform(-10.0, 10.0)
elif spec.dtype in [torch.complex64, torch.complex128]:
# Only generate complex values if not avoiding complex dtypes
if FuzzerConfig.avoid_complex:
raise ValueError("Cannot generate complex values with avoid_complex=True")
return complex(random.uniform(-10.0, 10.0), random.uniform(-10.0, 10.0))
else: # integer or bool
if spec.dtype == torch.bool:
return random.choice([True, False])
else:
return random.randint(-10, 10)
# Create a scalar value based on dtype
if spec.dtype.is_floating_point:
return random.uniform(-10.0, 10.0)
elif spec.dtype in [torch.complex64, torch.complex128]:
# Only generate complex values if not avoiding complex dtypes
if FuzzerConfig.avoid_complex:
raise ValueError(
"Cannot generate complex values with avoid_complex=True"
)
return complex(random.uniform(-10.0, 10.0), random.uniform(-10.0, 10.0))
else: # integer or bool
if spec.dtype == torch.bool:
return random.choice([True, False])
else:
return random.randint(-10, 10)
finally:
# Restore original random state
random.setstate(old_random_state)
else:
# Use current random state when no seed provided
# Create a scalar value based on dtype
if spec.dtype.is_floating_point:
return random.uniform(-10.0, 10.0)
elif spec.dtype in [torch.complex64, torch.complex128]:
# Only generate complex values if not avoiding complex dtypes
if FuzzerConfig.avoid_complex:
raise ValueError(
"Cannot generate complex values with avoid_complex=True"
)
return complex(random.uniform(-10.0, 10.0), random.uniform(-10.0, 10.0))
else: # integer or bool
if spec.dtype == torch.bool:
return random.choice([True, False])
else:
return random.randint(-10, 10)
def specs_compatible(spec1: Spec, spec2: Spec) -> bool:

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
"""Test to verify fuzzer produces deterministic output with same seed."""
import subprocess
import sys
from pathlib import Path
def run_fuzzer_with_seed(seed):
    """Run the fuzzer once with ``seed`` and return the generated program text.

    The ``*.py`` files in ``/tmp/torchfuzz`` are removed before the run so that
    files left over from an earlier invocation are not picked up afterwards.
    ``fuzzer.py`` is launched with the current interpreter, using the directory
    containing this file as the working directory.

    Args:
        seed: Value passed to the fuzzer's ``--seed`` option (converted with
            ``str``).

    Returns:
        The contents of the generated ``fuzz_*.py`` file, or ``None`` when the
        fuzzer exits with a non-zero status or no such file is found. In both
        failure cases a diagnostic message is printed.
    """
    torchfuzz_dir = Path("/tmp/torchfuzz")

    # Clear the output directory first
    if torchfuzz_dir.exists():
        for stale_file in torchfuzz_dir.glob("*.py"):
            stale_file.unlink()

    cmd = [sys.executable, "fuzzer.py", "--seed", str(seed)]
    result = subprocess.run(
        cmd, capture_output=True, text=True, cwd=Path(__file__).parent
    )

    if result.returncode != 0:
        print(f"Fuzzer failed with return code {result.returncode}")
        print(f"stdout: {result.stdout}")
        print(f"stderr: {result.stderr}")
        return None

    # Find the generated Python file in /tmp/torchfuzz/.
    # glob() yields entries in arbitrary directory order, so sort the matches:
    # if more than one file is present, every run must pick the same one,
    # otherwise this helper would itself introduce nondeterminism.
    py_files = sorted(torchfuzz_dir.glob("fuzz_*.py"))
    if not py_files:
        print("No Python files generated in /tmp/torchfuzz/")
        return None

    # Read the content of the generated file
    with open(py_files[0]) as f:
        return f.read()
def test_deterministic_output():
    """Check that repeated fuzzer runs with one seed generate identical code.

    The fuzzer is run three times with a fixed seed. Progress and the verdict
    are printed; when outputs differ, the first mismatching line of each
    differing run (compared against run 1) and any line-count mismatch are
    printed as well.

    Returns:
        ``True`` when every run produced the same text as the first run,
        ``False`` when any run failed to produce output or the outputs differ.
    """
    seed = 115306  # Seed taken from the user's issue
    num_runs = 3

    print(f"Running fuzzer {num_runs} times with seed {seed}...")

    generated = []
    for run_number in range(1, num_runs + 1):
        print(f"Run {run_number}...")
        program = run_fuzzer_with_seed(seed)
        if program is None:
            print(f"Failed to get output from run {run_number}")
            return False
        generated.append(program)

    # Every later run is compared against the first one.
    baseline, *later_runs = generated

    if not any(candidate != baseline for candidate in later_runs):
        print("✓ SUCCESS: All outputs are identical!")
        print(f"Generated code length: {len(baseline)} characters")
        return True

    print("✗ FAILURE: Outputs differ between runs!")

    # Show differences for debugging
    baseline_lines = baseline.splitlines()
    for run_number, candidate in enumerate(later_runs, start=2):
        if candidate == baseline:
            continue

        print(f"\nDifferences between run 1 and run {run_number}:")
        candidate_lines = candidate.splitlines()

        # Report only the first mismatching line within the common prefix.
        first_mismatch = next(
            (
                (line_number, expected, actual)
                for line_number, (expected, actual) in enumerate(
                    zip(baseline_lines, candidate_lines), start=1
                )
                if expected != actual
            ),
            None,
        )
        if first_mismatch is not None:
            line_number, expected, actual = first_mismatch
            print(f"Line {line_number}:")
            print(f" Run 1: {expected}")
            print(f" Run {run_number}: {actual}")

        if len(baseline_lines) != len(candidate_lines):
            print(
                f"Different number of lines: {len(baseline_lines)} vs {len(candidate_lines)}"
            )

    return False
def main():
    """Run the determinism check and exit with 0 on success, 1 on failure."""
    print("Testing fuzzer determinism...")
    print("=" * 50)

    # Guard clause: report the failure and stop with a non-zero status.
    if not test_deterministic_output():
        print("\n❌ Test FAILED: Fuzzer is not deterministic!")
        sys.exit(1)

    print("\n🎉 Test PASSED: Fuzzer is deterministic!")
    sys.exit(0)


# Only run the check when this file is executed as a script.
if __name__ == "__main__":
    main()