mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[torchfuzz] make fuzzer deterministic (#164397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164397 Approved by: https://github.com/pianpwk ghstack dependencies: #164034, #164209, #164211, #164210
This commit is contained in:
committed by
PyTorch MergeBot
parent
5dbae1eae2
commit
144378615a
@ -648,13 +648,13 @@ def create_program_file(python_code: str) -> str:
|
||||
Returns:
|
||||
Path to the created temporary file
|
||||
"""
|
||||
import random
|
||||
import hashlib
|
||||
|
||||
# Generate a random nonce for the filename
|
||||
nonce = random.randint(0, 1_000_000_000)
|
||||
# Generate a deterministic filename based on code content hash
|
||||
code_hash = hashlib.md5(python_code.encode()).hexdigest()[:8] # noqa: S324
|
||||
tmp_dir = "/tmp/torchfuzz"
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
generated_file_path = os.path.join(tmp_dir, f"fuzz_{nonce}.py")
|
||||
generated_file_path = os.path.join(tmp_dir, f"fuzz_{code_hash}.py")
|
||||
|
||||
# Write the generated code to the specified file
|
||||
with open(generated_file_path, "w") as f:
|
||||
|
@ -41,8 +41,10 @@ class ConstantOperator(Operator):
|
||||
) -> str:
|
||||
"""Generate code for constant creation."""
|
||||
# Create constant by calling fuzzing functions during codegen with deterministic seed
|
||||
# Use a deterministic seed based on the variable name to ensure reproducibility
|
||||
var_seed = hash(output_name) % (2**31)
|
||||
# Use a deterministic hash based on the variable name to ensure reproducibility across processes
|
||||
import hashlib
|
||||
|
||||
var_seed = int(hashlib.md5(output_name.encode()).hexdigest()[:8], 16) % (2**31) # noqa: S324
|
||||
|
||||
if isinstance(output_spec, ScalarSpec):
|
||||
# Call fuzz_scalar during codegen and embed the result
|
||||
|
@ -377,8 +377,11 @@ def fuzz_operation_graph(
|
||||
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
# Reset global arg counter for deterministic behavior
|
||||
global _next_arg_id
|
||||
_next_arg_id = 0
|
||||
|
||||
# Global counter for unique node IDs
|
||||
# Global counter for unique node IDs - start from 0 for deterministic behavior
|
||||
node_counter = 0
|
||||
|
||||
# Dictionary to store all nodes: node_id -> OperationNode
|
||||
|
@ -357,55 +357,69 @@ def fuzz_tensor(
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
# Set the random seed for reproducibility
|
||||
# Create a local Random instance to avoid interfering with global state
|
||||
local_random = random.Random(seed)
|
||||
|
||||
# Set the torch random seed for reproducibility
|
||||
# Save and restore global torch state to avoid side effects
|
||||
torch_state = torch.get_rng_state()
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
# Generate random values if not provided
|
||||
if size is None:
|
||||
size = fuzz_tensor_size()
|
||||
# Generate random values if not provided using local random instance
|
||||
old_random_state = random.getstate()
|
||||
try:
|
||||
# Temporarily use local random instance for deterministic generation
|
||||
random.setstate(local_random.getstate())
|
||||
|
||||
if dtype is None:
|
||||
dtype = fuzz_torch_tensor_type("default")
|
||||
if size is None:
|
||||
size = fuzz_tensor_size()
|
||||
|
||||
if stride is None:
|
||||
stride = fuzz_valid_stride(size)
|
||||
if dtype is None:
|
||||
dtype = fuzz_torch_tensor_type("default")
|
||||
|
||||
# Handle empty tensor case
|
||||
if len(size) == 0:
|
||||
return torch.ones((), dtype=dtype), seed
|
||||
if stride is None:
|
||||
stride = fuzz_valid_stride(size)
|
||||
|
||||
# Calculate required storage size for the custom stride
|
||||
required_storage = _compute_storage_size_needed(size, stride)
|
||||
# Handle empty tensor case
|
||||
if len(size) == 0:
|
||||
return torch.ones((), dtype=dtype), seed
|
||||
|
||||
# Create base tensor with sufficient storage
|
||||
if FuzzerConfig.use_real_values:
|
||||
# Use random values based on dtype
|
||||
if dtype.is_floating_point:
|
||||
base_tensor = torch.randn(required_storage, dtype=dtype)
|
||||
elif dtype in [torch.complex64, torch.complex128]:
|
||||
# Create complex tensor with random real and imaginary parts
|
||||
real_part = torch.randn(
|
||||
required_storage,
|
||||
dtype=torch.float32 if dtype == torch.complex64 else torch.float64,
|
||||
)
|
||||
imag_part = torch.randn(
|
||||
required_storage,
|
||||
dtype=torch.float32 if dtype == torch.complex64 else torch.float64,
|
||||
)
|
||||
base_tensor = torch.complex(real_part, imag_part).to(dtype)
|
||||
elif dtype == torch.bool:
|
||||
base_tensor = torch.randint(0, 2, (required_storage,), dtype=torch.bool)
|
||||
else: # integer types
|
||||
base_tensor = torch.randint(-100, 100, (required_storage,), dtype=dtype)
|
||||
else:
|
||||
# Use ones (default behavior)
|
||||
base_tensor = torch.ones(required_storage, dtype=dtype)
|
||||
# Calculate required storage size for the custom stride
|
||||
required_storage = _compute_storage_size_needed(size, stride)
|
||||
|
||||
# Create strided tensor view
|
||||
strided_tensor = torch.as_strided(base_tensor, size, stride)
|
||||
# Create base tensor with sufficient storage
|
||||
if FuzzerConfig.use_real_values:
|
||||
# Use random values based on dtype
|
||||
if dtype.is_floating_point:
|
||||
base_tensor = torch.randn(required_storage, dtype=dtype)
|
||||
elif dtype in [torch.complex64, torch.complex128]:
|
||||
# Create complex tensor with random real and imaginary parts
|
||||
real_part = torch.randn(
|
||||
required_storage,
|
||||
dtype=torch.float32 if dtype == torch.complex64 else torch.float64,
|
||||
)
|
||||
imag_part = torch.randn(
|
||||
required_storage,
|
||||
dtype=torch.float32 if dtype == torch.complex64 else torch.float64,
|
||||
)
|
||||
base_tensor = torch.complex(real_part, imag_part).to(dtype)
|
||||
elif dtype == torch.bool:
|
||||
base_tensor = torch.randint(0, 2, (required_storage,), dtype=torch.bool)
|
||||
else: # integer types
|
||||
base_tensor = torch.randint(-100, 100, (required_storage,), dtype=dtype)
|
||||
else:
|
||||
# Use ones (default behavior)
|
||||
base_tensor = torch.ones(required_storage, dtype=dtype)
|
||||
|
||||
return strided_tensor, seed
|
||||
# Create strided tensor view
|
||||
strided_tensor = torch.as_strided(base_tensor, size, stride)
|
||||
|
||||
return strided_tensor, seed
|
||||
finally:
|
||||
# Restore original random state
|
||||
random.setstate(old_random_state)
|
||||
# Restore original torch state
|
||||
torch.set_rng_state(torch_state)
|
||||
|
||||
|
||||
def fuzz_tensor_simple(
|
||||
@ -493,23 +507,49 @@ def fuzz_scalar(spec, seed: Optional[int] = None) -> Union[float, int, bool, com
|
||||
if spec.constant is not None:
|
||||
return spec.constant
|
||||
|
||||
# Set seed for reproducibility if provided
|
||||
# Create a local random instance to avoid interfering with global state
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
local_random = random.Random(seed)
|
||||
# Save and restore global random state
|
||||
old_random_state = random.getstate()
|
||||
try:
|
||||
random.setstate(local_random.getstate())
|
||||
|
||||
# Create a scalar value based on dtype
|
||||
if spec.dtype.is_floating_point:
|
||||
return random.uniform(-10.0, 10.0)
|
||||
elif spec.dtype in [torch.complex64, torch.complex128]:
|
||||
# Only generate complex values if not avoiding complex dtypes
|
||||
if FuzzerConfig.avoid_complex:
|
||||
raise ValueError("Cannot generate complex values with avoid_complex=True")
|
||||
return complex(random.uniform(-10.0, 10.0), random.uniform(-10.0, 10.0))
|
||||
else: # integer or bool
|
||||
if spec.dtype == torch.bool:
|
||||
return random.choice([True, False])
|
||||
else:
|
||||
return random.randint(-10, 10)
|
||||
# Create a scalar value based on dtype
|
||||
if spec.dtype.is_floating_point:
|
||||
return random.uniform(-10.0, 10.0)
|
||||
elif spec.dtype in [torch.complex64, torch.complex128]:
|
||||
# Only generate complex values if not avoiding complex dtypes
|
||||
if FuzzerConfig.avoid_complex:
|
||||
raise ValueError(
|
||||
"Cannot generate complex values with avoid_complex=True"
|
||||
)
|
||||
return complex(random.uniform(-10.0, 10.0), random.uniform(-10.0, 10.0))
|
||||
else: # integer or bool
|
||||
if spec.dtype == torch.bool:
|
||||
return random.choice([True, False])
|
||||
else:
|
||||
return random.randint(-10, 10)
|
||||
finally:
|
||||
# Restore original random state
|
||||
random.setstate(old_random_state)
|
||||
else:
|
||||
# Use current random state when no seed provided
|
||||
# Create a scalar value based on dtype
|
||||
if spec.dtype.is_floating_point:
|
||||
return random.uniform(-10.0, 10.0)
|
||||
elif spec.dtype in [torch.complex64, torch.complex128]:
|
||||
# Only generate complex values if not avoiding complex dtypes
|
||||
if FuzzerConfig.avoid_complex:
|
||||
raise ValueError(
|
||||
"Cannot generate complex values with avoid_complex=True"
|
||||
)
|
||||
return complex(random.uniform(-10.0, 10.0), random.uniform(-10.0, 10.0))
|
||||
else: # integer or bool
|
||||
if spec.dtype == torch.bool:
|
||||
return random.choice([True, False])
|
||||
else:
|
||||
return random.randint(-10, 10)
|
||||
|
||||
|
||||
def specs_compatible(spec1: Spec, spec2: Spec) -> bool:
|
||||
|
107
tools/experimental/dynamic_shapes/torchfuzz/test_determinism.py
Normal file
107
tools/experimental/dynamic_shapes/torchfuzz/test_determinism.py
Normal file
@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test to verify fuzzer produces deterministic output with same seed."""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_fuzzer_with_seed(seed):
    """Run the fuzzer once with ``seed`` and return the generated program text.

    Any ``*.py`` files already present in ``/tmp/torchfuzz`` are deleted first
    so that the file read back afterwards was produced by this run.
    ``fuzzer.py`` is launched with the current interpreter, using this file's
    directory as the working directory.

    Args:
        seed: Value passed to the fuzzer's ``--seed`` option (via ``str``).

    Returns:
        The contents of the generated ``fuzz_*.py`` file, or ``None`` if the
        fuzzer exited with a non-zero status or no such file was found.
    """
    cmd = [sys.executable, "fuzzer.py", "--seed", str(seed)]

    # Clear the output directory first so stale programs are not picked up.
    torchfuzz_dir = Path("/tmp/torchfuzz")
    if torchfuzz_dir.exists():
        for stale_file in torchfuzz_dir.glob("*.py"):
            stale_file.unlink()

    result = subprocess.run(
        cmd, capture_output=True, text=True, cwd=Path(__file__).parent
    )

    if result.returncode != 0:
        print(f"Fuzzer failed with return code {result.returncode}")
        print(f"stdout: {result.stdout}")
        print(f"stderr: {result.stderr}")
        return None

    # Find the generated Python file in /tmp/torchfuzz/.
    # Directory listing order is arbitrary, so sort the matches: if more than
    # one file is present, every run then reads the same one and the
    # comparison does not report a spurious difference.
    py_files = sorted(torchfuzz_dir.glob("fuzz_*.py"))
    if not py_files:
        print("No Python files generated in /tmp/torchfuzz/")
        return None

    # Read the content of the generated file
    with open(py_files[0]) as f:
        return f.read()
|
||||
|
||||
|
||||
def test_deterministic_output():
    """Check that repeated fuzzer runs with one seed yield identical programs.

    The fuzzer is run three times with a fixed seed and every later output is
    compared against the first. On a mismatch, the first differing line and
    any difference in line count are printed for each run that diverged.

    Returns:
        True when all outputs are identical; False when a run produced no
        output or when any output differs from the first.
    """
    # Seed taken from a user-reported issue.
    seed = 115306
    num_runs = 3

    print(f"Running fuzzer {num_runs} times with seed {seed}...")

    outputs = []
    for run_number in range(1, num_runs + 1):
        print(f"Run {run_number}...")
        generated = run_fuzzer_with_seed(seed)
        if generated is None:
            print(f"Failed to get output from run {run_number}")
            return False
        outputs.append(generated)

    # The first run is the reference; collect every later run that deviates,
    # remembering its 1-based run number for the report.
    reference, *later_runs = outputs
    mismatches = [
        (run_number, text)
        for run_number, text in enumerate(later_runs, 2)
        if text != reference
    ]

    if not mismatches:
        print("✓ SUCCESS: All outputs are identical!")
        print(f"Generated code length: {len(reference)} characters")
        return True

    print("✗ FAILURE: Outputs differ between runs!")

    # Show differences for debugging
    reference_lines = reference.splitlines()
    for run_number, text in mismatches:
        print(f"\nDifferences between run 1 and run {run_number}:")
        candidate_lines = text.splitlines()

        # zip stops at the shorter output, so only lines present in both runs
        # are compared; only the first differing line is reported.
        paired_lines = zip(reference_lines, candidate_lines)
        for line_no, (expected, actual) in enumerate(paired_lines, 1):
            if expected != actual:
                print(f"Line {line_no}:")
                print(f" Run 1: {expected}")
                print(f" Run {run_number}: {actual}")
                break

        if len(reference_lines) != len(candidate_lines):
            print(
                f"Different number of lines: {len(reference_lines)} vs {len(candidate_lines)}"
            )

    return False
|
||||
|
||||
|
||||
def main():
    """Run the determinism check and exit with status 0 on success, 1 on failure."""
    print("Testing fuzzer determinism...")
    print("=" * 50)

    if test_deterministic_output():
        print("\n🎉 Test PASSED: Fuzzer is deterministic!")
        exit_code = 0
    else:
        print("\n❌ Test FAILED: Fuzzer is not deterministic!")
        exit_code = 1

    sys.exit(exit_code)
|
||||
|
||||
|
||||
# Run the determinism check only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
Reference in New Issue
Block a user