Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[torchfuzz] Add support for fuzz templates (#163890)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163890
Approved by: https://github.com/pianpwk
ghstack dependencies: #163743, #163812

Committed by PyTorch MergeBot (parent 0ebfa3d7d2, commit 19f16a65b4)

tools/experimental/dynamic_shapes/torchfuzz/checks.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""Check abstractions for different execution modes and validations."""

from abc import ABC, abstractmethod


class Check(ABC):
    """Base class for execution checks."""

    @abstractmethod
    def codegen(self, args_tuple: str) -> list[str]:
        """Generate code lines for this check."""


class EagerVsFullGraphDynamicCompileCheck(Check):
    """Standard check that runs eager then fullgraph+dynamic compilation."""

    def codegen(self, args_tuple: str) -> list[str]:
        return [
            f"args = {args_tuple}",
            "result_original = fuzzed_program(*args)",
            "print('✅ eager success')",
            "compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)",
            "result_compiled = compiled_program(*args)",
            "print('✅ compile success')",
        ]
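The check's `codegen` lines are spliced verbatim into the generated harness. A minimal sketch of rendering them for a given argument tuple, assuming the torchfuzz package (which lives under tools/experimental/dynamic_shapes in the tree) is importable:

```python
from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck

check = EagerVsFullGraphDynamicCompileCheck()
for line in check.codegen("(arg_0, arg_1) + (sentinel,)"):
    print(line)
# Prints the eager call, the torch.compile(fullgraph=True, dynamic=True)
# wrapper, and the compiled call that the fuzz harness will execute.
```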
@@ -10,8 +10,324 @@ from torchfuzz.tensor_descriptor import format_tensor_descriptor
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec


class FuzzTemplate:
    def __init__(self, supported_ops, check):
        self.supported_ops = supported_ops
        self.check = check

    def supported_dtypes(self):
        """Return list of supported dtypes for this template."""
        return [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.bfloat16,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.bool,
        ]


class DefaultFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck

        super().__init__(
            supported_ops=[
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
            ],
            check=EagerVsFullGraphDynamicCompileCheck(),
        )

    def imports_codegen(self):
        return [
            "import torch",
        ]

    def flags_codegen(self):
        return ["torch._dynamo.config.capture_scalar_outputs = True"]

    def args_codegen(self, arg_operations):
        """Generate argument creation code for default template."""
        code_lines = []

        # Add sentinel tensor that ensures gradient computation
        code_lines.extend(
            [
                "# Sentinel tensor to ensure gradient computation",
                "sentinel = torch.tensor(1.0, requires_grad=True)",
                "",
            ]
        )

        if arg_operations:
            for i, (node_id, spec) in enumerate(arg_operations):
                arg_name = f"arg_{i}"

                if isinstance(spec, ScalarSpec):
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    code_lines.append(
                        f"{arg_name} = torch.tensor(torch.randn(()), dtype={dtype_str}).item()"
                    )

                elif isinstance(spec, TensorSpec):
                    size_str = str(spec.size)
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")

                    # Calculate storage size needed for the strided tensor
                    if spec.size:
                        # Calculate the maximum index that will be accessed
                        max_offset = 0
                        for dim_size, stride in zip(spec.size, spec.stride):
                            if dim_size > 1:
                                max_offset += (dim_size - 1) * abs(stride)
                        storage_size = max_offset + 1
                    else:
                        storage_size = 1

                    stride_str = str(spec.stride)
                    code_lines.append(
                        f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"
                    )

        return code_lines

    def epilogue_codegen(self):
        return []


class DTensorFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck

        super().__init__(
            supported_ops=[
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
            ],
            check=EagerVsFullGraphDynamicCompileCheck(),
        )

    def supported_dtypes(self):
        """Return list of DTensor-compatible dtypes (no complex types)."""
        return [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.bfloat16,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.bool,
        ]

    def imports_codegen(self):
        return [
            "import torch",
            "from torch.distributed.tensor.placement_types import Replicate, Shard",
            "from torch.testing._internal.distributed.fake_pg import FakeStore",
            "from torch.distributed.tensor import DTensor",
        ]

    def flags_codegen(self):
        return [
            "torch._dynamo.config.capture_scalar_outputs = True",
            "torch._dynamo.config.capture_dynamic_output_shape_ops = True",
            "torch._inductor.config.emulate_precision_casts = True",
        ]

    def args_codegen(self, arg_operations):
        """Generate DTensor argument creation code with proper mesh setup."""
        code_lines = []

        # Add DTensor setup code first
        code_lines.extend(
            [
                "world_size = 1024",
                "fake_store = FakeStore()",
                "torch.distributed.init_process_group(",
                '    "fake", store=fake_store, rank=0, world_size=world_size',
                ")",
                "",
                "mesh = torch.distributed.device_mesh.init_device_mesh(",
                '    "cuda",',
                "    (2, 8),",
                "    mesh_dim_names=(",
                '        "dim1", "dim2",',
                "    ),",
                ")",
                "",
                "placements = (Replicate(), Replicate())",
                "",
                "# Sentinel tensor to ensure gradient computation",
                "sentinel_local = torch.tensor(1.0, device='cuda', requires_grad=True)",
                "sentinel = DTensor.from_local(sentinel_local, mesh, placements)",
                "",
            ]
        )

        if arg_operations:
            for i, (node_id, spec) in enumerate(arg_operations):
                arg_name = f"arg_{i}"

                if isinstance(spec, ScalarSpec):
                    # For scalars in DTensor, create a 0-dim tensor
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    code_lines.extend(
                        [
                            f"{arg_name}_local = torch.randn((), dtype={dtype_str}, device='cuda', requires_grad=True)",
                            f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                        ]
                    )

                elif isinstance(spec, TensorSpec):
                    size_str = str(spec.size)
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")

                    # Handle different dtypes appropriately for DTensor
                    if spec.dtype in [
                        torch.int32,
                        torch.int64,
                        torch.int8,
                        torch.int16,
                    ]:
                        # Integer dtypes: use randint and no requires_grad
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randint(1, 10, {size_str}, dtype={dtype_str}, device='cuda')",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )
                    elif spec.dtype == torch.bool:
                        # Boolean dtype: use randint and cast to bool
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randint(0, 2, {size_str}, device='cuda').bool()",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )
                    else:
                        # Float dtypes: use randn and requires_grad
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randn({size_str}, dtype={dtype_str}, device='cuda', requires_grad=True)",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )

        return code_lines

    def epilogue_codegen(self):
        return ["torch.distributed.destroy_process_group()"]


class UnbackedFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck

        super().__init__(
            supported_ops=[
                "torch.ops.aten.item",
                "torch.ops.aten.nonzero",
                "torch.ops.aten.masked_select",
                "torch.ops.aten.unique",
                # Include basic operations for building up data
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
            ],
            check=EagerVsFullGraphDynamicCompileCheck(),
        )

    def supported_dtypes(self):
        """Return list of dtypes good for data-dependent operations."""
        # Focus on dtypes that work well with data-dependent ops and arithmetic
        # Exclude bool since arithmetic operations don't work with boolean tensors
        return [
            torch.float32,
            torch.float64,
            torch.int32,
            torch.int64,
        ]

    def imports_codegen(self):
        return [
            "import torch",
        ]

    def flags_codegen(self):
        return [
            "torch._dynamo.config.capture_scalar_outputs = True",
            "torch._dynamo.config.capture_dynamic_output_shape_ops = True",
        ]

    def args_codegen(self, arg_operations):
        """Generate argument creation code for unbacked template."""
        code_lines = []

        # Add sentinel tensor that ensures gradient computation
        code_lines.extend(
            [
                "# Sentinel tensor to ensure gradient computation",
                "sentinel = torch.tensor(1.0, requires_grad=True)",
                "",
            ]
        )

        if arg_operations:
            for i, (node_id, spec) in enumerate(arg_operations):
                arg_name = f"arg_{i}"

                if isinstance(spec, ScalarSpec):
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    code_lines.append(
                        f"{arg_name} = torch.tensor(torch.randn(()), dtype={dtype_str}).item()"
                    )

                elif isinstance(spec, TensorSpec):
                    size_str = str(spec.size)
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")

                    # For unbacked operations, create tensors with specific patterns
                    # that are likely to produce meaningful results
                    if spec.dtype == torch.bool:
                        # For boolean tensors, create a mix of True/False values
                        code_lines.append(
                            f"{arg_name} = torch.randint(0, 2, {size_str}, dtype={dtype_str}) > 0"
                        )
                    elif spec.dtype in [torch.int32, torch.int64]:
                        # For integer tensors, create values that will have some duplicates
                        # and some unique values for operations like unique()
                        code_lines.append(
                            f"{arg_name} = torch.randint(0, 3, {size_str}, dtype={dtype_str})"
                        )
                    else:
                        # For float tensors, create values with some zeros and non-zeros
                        code_lines.append(
                            f"{arg_name} = (torch.randn({size_str}) * 2).to({dtype_str})"
                        )
                        # Zero out some values to make nonzero operations meaningful
                        code_lines.append(f"{arg_name}[{arg_name}.abs() < 0.5] = 0")

        return code_lines

    def epilogue_codegen(self):
        return []

def convert_graph_to_python_code(
    operation_graph: OperationGraph, seed: Optional[int] = None
    operation_graph: OperationGraph,
    seed: Optional[int] = None,
    template: str = "default",
) -> str:
    """
    Convert an operation graph to executable Python code using topological ordering.

@@ -29,6 +345,14 @@ def convert_graph_to_python_code(
        String containing the complete Python code that executes the operations
    """

    # Instantiate template
    if template == "dtensor":
        fuzz_template = DTensorFuzzTemplate()
    elif template == "unbacked":
        fuzz_template = UnbackedFuzzTemplate()
    else:
        fuzz_template = DefaultFuzzTemplate()

    # Set seed for reproducible code generation
    if seed is not None:
        import random

@@ -100,16 +424,19 @@ def convert_graph_to_python_code(
    # Generate function signature based on discovered arg operations
    if arg_operations:
        arg_names = [f"arg_{i}" for i in range(len(arg_operations))]
        function_signature = f"def fuzzed_program({', '.join(arg_names)})"
        function_signature = f"def fuzzed_program({', '.join(arg_names)}, sentinel)"
    else:
        function_signature = "def fuzzed_program()"
        function_signature = "def fuzzed_program(sentinel)"

    # Build the complete code - all imports at the top
    code_lines = [
        "import torch",
        "torch._dynamo.config.capture_scalar_outputs = True",
        "",
    ]
    code_lines = []

    # Add template imports
    code_lines.extend(fuzz_template.imports_codegen())

    # Add template flags
    code_lines.extend(fuzz_template.flags_codegen())
    code_lines.append("")

    # Add single seed at the top if seed is provided
    if seed is not None:

@@ -121,44 +448,36 @@ def convert_graph_to_python_code(
    # Add the generated operation code
    code_lines.extend(generated_code_lines)

    # Add return statement
    code_lines.extend(
        [
            f"    return {final_var_name}",
            "",
        ]
    )
    # Add return statement with sentinel multiplication to ensure gradient computation
    # Handle complex tensors appropriately based on template
    if template == "dtensor":
        # For DTensor, avoid .real operation which doesn't work with sharding
        # Instead use abs() for complex tensors to get a real result
        code_lines.extend(
            [
                "    # Ensure gradient computation by multiplying with sentinel",
                f"    result = {final_var_name} * sentinel",
                "    if result.is_complex():",
                "        result = result.abs()  # Use abs() instead of .real for DTensor compatibility",
                "    return result",
                "",
            ]
        )
    else:
        code_lines.extend(
            [
                "    # Ensure gradient computation by multiplying with sentinel and taking real part",
                f"    result = {final_var_name} * sentinel",
                "    if result.is_complex():",
                "        result = result.real",
                "    return result",
                "",
            ]
        )

    # Generate argument creation code without individual seeds
    if arg_operations:
        for i, (node_id, spec) in enumerate(arg_operations):
            arg_name = f"arg_{i}"

            if isinstance(spec, ScalarSpec):
                dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                code_lines.append(
                    f"{arg_name} = torch.tensor(torch.randn(()), dtype={dtype_str}).item()"
                )

            elif isinstance(spec, TensorSpec):
                size_str = str(spec.size)
                stride_str = str(spec.stride)
                dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")

                # Calculate storage size needed for the strided tensor
                if spec.size:
                    storage_size = 1
                    for dim_size, stride in zip(spec.size, spec.stride):
                        if dim_size > 1:
                            storage_size = max(
                                storage_size, (dim_size - 1) * abs(stride) + 1
                            )
                else:
                    storage_size = 1

                code_lines.append(
                    f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"
                )
    # Generate argument creation code using template
    arg_code_lines = fuzz_template.args_codegen(arg_operations)
    code_lines.extend(arg_code_lines)

    # Generate the final execution with both normal and compiled versions
    if arg_operations:

@@ -172,17 +491,15 @@ def convert_graph_to_python_code(
    else:
        args_tuple = "()"

    code_lines.extend(
        [
            "",
            f"args = {args_tuple}",
            "result_original = fuzzed_program(*args)",
            "print('✅ eager success')",
            "compiled_program = torch.compile(fuzzed_program, fullgraph=False, dynamic=True)",
            "result_compiled = compiled_program(*args)",
            "print('✅ compile success')",
        ]
    )
    # Generate execution code using template check
    check_lines = fuzz_template.check.codegen(f"{args_tuple} + (sentinel,)")
    code_lines.extend([""] + check_lines)

    # Add template epilogue
    epilogue_lines = fuzz_template.epilogue_codegen()
    if epilogue_lines:
        code_lines.append("")
        code_lines.extend(epilogue_lines)

    return "\n".join(code_lines)
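Putting the pieces together, the default template yields a self-checking script roughly like the following. This is a hand-written sketch for a single tensor argument; the operation body (`var_node_1` here), sizes, and strides are whatever the graph fuzzer happened to pick:

```python
import torch
torch._dynamo.config.capture_scalar_outputs = True

# Sentinel tensor to ensure gradient computation
sentinel = torch.tensor(1.0, requires_grad=True)

arg_0 = torch.as_strided(torch.randn(6).to(torch.float32), (2, 3), (3, 1))

def fuzzed_program(arg_0, sentinel):
    var_node_1 = torch.add(arg_0, arg_0)  # illustrative operation body
    # Ensure gradient computation by multiplying with sentinel
    result = var_node_1 * sentinel
    if result.is_complex():
        result = result.real
    return result

args = (arg_0,) + (sentinel,)
result_original = fuzzed_program(*args)
print('✅ eager success')
compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
result_compiled = compiled_program(*args)
print('✅ compile success')
```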
@@ -24,6 +24,7 @@ def fuzz_and_execute(
    seed: Optional[int] = None,
    max_depth: Optional[int] = None,
    log_at_faluire: bool = False,
    template: str = "default",
) -> None:
    """
    Generate a fuzzed operation stack, convert it to Python code, and execute it.

@@ -111,7 +112,7 @@ def fuzz_and_execute(
    # Generate target specification first
    logger.debug("⏱️ Step 1: Generating target spec...")
    start_time = time.time()
    target_spec = fuzz_spec()
    target_spec = fuzz_spec(template)
    logger.debug(
        "  Completed in %.3fs - %s", time.time() - start_time, target_spec
    )

@@ -119,11 +120,13 @@ def fuzz_and_execute(
    logger.debug("⏱️ Step 2: Generating operation graph...")
    start_time = time.time()
    operation_graph = fuzz_operation_graph(
        target_spec, max_depth=max_depth, seed=seed
        target_spec, max_depth=max_depth, seed=seed, template=template
    )
    logger.debug("⏱️ Step 3: Converting to Python code...")
    start_time = time.time()
    python_code = convert_graph_to_python_code(operation_graph, seed=seed)
    python_code = convert_graph_to_python_code(
        operation_graph, seed=seed, template=template
    )
    logger.debug(
        "  Completed in %.3fs - %d chars",
        time.time() - start_time,

@@ -179,6 +182,12 @@ if __name__ == "__main__":
    parser.add_argument(
        "--max-depth", type=int, help="Maximum depth for operation stack (1-20)"
    )
    parser.add_argument(
        "--template",
        choices=["default", "dtensor", "unbacked"],
        default="default",
        help="Template to use for code generation (default: default)",
    )

    # Multi-process fuzzing arguments
    parser.add_argument(

@@ -225,7 +234,9 @@ if __name__ == "__main__":
    if args.seed is not None or args.single:
        # Single seed execution mode
        print("Running single fuzz_and_execute...")
        fuzz_and_execute(seed=args.seed, max_depth=args.max_depth)
        fuzz_and_execute(
            seed=args.seed, max_depth=args.max_depth, template=args.template
        )
    elif args.start is not None or args.count is not None:
        # Multi-process fuzzing mode
        if args.start is None:

@@ -255,6 +266,7 @@ if __name__ == "__main__":
            seed_start=args.start,
            seed_count=args.count,
            verbose=args.verbose,
            template=args.template,
        )
    except Exception as e:
        print(f"❌ Unexpected error: {str(e)}")
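With the new `--template` flag plumbed through, a single run can be driven from the command line (`python fuzzer.py --single --seed 42 --template unbacked`) or directly from Python. A minimal sketch, assuming the torchfuzz directory is on `sys.path` so that `fuzzer` imports:

```python
from fuzzer import fuzz_and_execute

# Generate, codegen, and execute one fuzzed program per template.
fuzz_and_execute(seed=42, max_depth=5, template="default")
fuzz_and_execute(seed=42, max_depth=5, template="unbacked")
# The dtensor template additionally needs a CUDA device; the generated
# script sets up its own fake process group and device mesh.
fuzz_and_execute(seed=42, max_depth=5, template="dtensor")
```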
@@ -9,6 +9,7 @@ import subprocess
import sys
import time
from dataclasses import dataclass
from typing import Optional


try:

@@ -79,12 +80,13 @@ def is_ignored_output(output: str) -> int:
    return -1


def run_fuzzer_with_seed(seed: int) -> FuzzerResult:
def run_fuzzer_with_seed(seed: int, template: str = "default") -> FuzzerResult:
    """
    Run fuzzer.py with a specific seed.

    Args:
        seed: The seed value to pass to fuzzer.py
        template: The template to use for code generation

    Returns:
        FuzzerResult dataclass instance

@@ -92,8 +94,16 @@ def run_fuzzer_with_seed(seed: int) -> FuzzerResult:
    start_time = time.time()

    try:
        # Run fuzzer.py with the specified seed
        cmd = [sys.executable, "fuzzer.py", "--single", "--seed", str(seed)]
        # Run fuzzer.py with the specified seed and template
        cmd = [
            sys.executable,
            "fuzzer.py",
            "--single",
            "--seed",
            str(seed),
            "--template",
            template,
        ]

        result = subprocess.run(
            cmd,

@@ -160,10 +170,11 @@ def handle_result_output(


def run_multi_process_fuzzer(
    num_processes: int = 2,
    seed_start: int = 1,
    seed_count: int = 10,
    num_processes: Optional[int] = None,
    seed_start: int = 0,
    seed_count: int = 100,
    verbose: bool = False,
    template: str = "default",
) -> None:
    """
    Run the multi-process fuzzer.

@@ -180,7 +191,9 @@ def run_multi_process_fuzzer(
    persist_print(
        f"📊 Processing seeds {seed_start} to {seed_start + seed_count - 1} ({len(seeds)} total)"
    )
    persist_print("🔧 Command template: python fuzzer.py --seed {seed}")
    persist_print(
        f"🔧 Command template: python fuzzer.py --seed {{seed}} --template {template}"
    )
    persist_print("=" * 60)

    start_time = time.time()

@@ -199,7 +212,7 @@ def run_multi_process_fuzzer(
    # Submit all seeds to the process pool
    future_results = []
    for seed in seeds:
        future = pool.apply_async(run_fuzzer_with_seed, (seed,))
        future = pool.apply_async(run_fuzzer_with_seed, (seed, template))
        future_results.append(future)

    # Set up progress bar
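The worker entry point is unchanged apart from the extra argument, so sweeping a seed range under one template stays a one-liner. A minimal sketch, assuming the worker module is the multi-process fuzzer file shown above and that it is launched from the torchfuzz directory so `fuzzer.py` resolves:

```python
from multi_process_fuzzer import run_multi_process_fuzzer

# Fan 100 seeds out across worker processes, all using the dtensor template;
# each worker shells out to `python fuzzer.py --single --seed N --template dtensor`.
run_multi_process_fuzzer(
    num_processes=8,
    seed_start=0,
    seed_count=100,
    verbose=False,
    template="dtensor",
)
```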
@@ -5,14 +5,34 @@ from torchfuzz.operators.base import Operator
from torchfuzz.operators.constant import ConstantOperator
from torchfuzz.operators.item import ItemOperator
from torchfuzz.operators.registry import get_operator, list_operators, register_operator
from torchfuzz.operators.scalar_pointwise import ScalarPointwiseOperator
from torchfuzz.operators.tensor_pointwise import TensorPointwiseOperator
from torchfuzz.operators.scalar_pointwise import (
    ScalarAddOperator,
    ScalarDivOperator,
    ScalarMulOperator,
    ScalarPointwiseOperator,
    ScalarSubOperator,
)
from torchfuzz.operators.tensor_pointwise import (
    AddOperator,
    DivOperator,
    MulOperator,
    PointwiseOperator,
    SubOperator,
)


__all__ = [
    "Operator",
    "TensorPointwiseOperator",
    "PointwiseOperator",
    "AddOperator",
    "MulOperator",
    "SubOperator",
    "DivOperator",
    "ScalarPointwiseOperator",
    "ScalarAddOperator",
    "ScalarMulOperator",
    "ScalarSubOperator",
    "ScalarDivOperator",
    "ItemOperator",
    "ConstantOperator",
    "ArgOperator",
@@ -1,5 +1,7 @@
"""Arg operator implementation."""

from typing import Optional

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec

@@ -10,6 +12,11 @@ class ArgOperator(Operator):
    def __init__(self):
        super().__init__("arg")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Arg is not a torch operation, it represents function arguments."""
        return None

    def can_produce(self, output_spec: Spec) -> bool:
        """Arg can produce any type of output."""
        return True
@@ -1,6 +1,7 @@
"""Base operator implementation."""

from abc import ABC, abstractmethod
from typing import Optional

from torchfuzz.tensor_fuzzer import Spec

@@ -12,6 +13,17 @@ class Operator(ABC):
        """Initialize operator with name."""
        self.name = name

    @property
    @abstractmethod
    def torch_op_name(self) -> Optional[str]:
        """
        Return the torch operation name this operator represents.

        Returns:
            Optional[str]: The torch operation name (e.g., "torch.ops.aten.add", "torch.nonzero").
                Returns None for non-torch operations like "arg" and "constant".
        """

    @abstractmethod
    def can_produce(self, output_spec: Spec) -> bool:
        """Check if this operator can produce the given output spec."""
@@ -1,5 +1,7 @@
"""Constant operator implementation."""

from typing import Optional

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import (
    fuzz_scalar,

@@ -15,6 +17,16 @@ class ConstantOperator(Operator):

    def __init__(self):
        super().__init__("constant")
        self.template = "default"  # Track template for DTensor compatibility

    @property
    def torch_op_name(self) -> Optional[str]:
        """Constant is not a torch operation, it generates constant values."""
        return None

    def set_template(self, template: str):
        """Set the template for context-aware code generation."""
        self.template = template

    def can_produce(self, output_spec: Spec) -> bool:
        """Constant can produce any type of output."""

@@ -77,11 +89,24 @@ class ConstantOperator(Operator):
                    torch.complex128: 0.0,
                }
                fill_value = default_values.get(output_spec.dtype, 0)
                return f"{output_name} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"
                tensor_creation = (
                    f"torch.full({size_str}, {fill_value}, dtype={dtype_str})"
                )
            else:
                # For non-empty tensors, use the first element as fill value
                fill_value = actual_tensor.flatten()[0].item()
                return f"{output_name} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"
                tensor_creation = (
                    f"torch.full({size_str}, {fill_value}, dtype={dtype_str})"
                )

            # For DTensor template, convert to DTensor
            if self.template == "dtensor":
                return (
                    f"{output_name}_local = {tensor_creation}.to('cuda')\n"
                    f"    {output_name} = DTensor.from_local({output_name}_local, mesh, placements)"
                )
            else:
                return f"{output_name} = {tensor_creation}"

        else:
            return f"# Unknown output spec type for constant: {type(output_spec)}"
@@ -1,17 +1,24 @@
"""Item operator implementation."""

from typing import Optional

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec


class ItemOperator(Operator):
    """Operator for extracting a scalar from a tensor."""
    """Operator for converting 0-d tensor to scalar."""

    def __init__(self):
        super().__init__("torch.ops.aten.item")
        super().__init__("item")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Item is a tensor method, not a direct torch operation."""
        return None

    def can_produce(self, output_spec: Spec) -> bool:
        """Item can only produce scalars."""
        """Item produces scalars from 0-d tensors."""
        return isinstance(output_spec, ScalarSpec)

    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:

@@ -20,8 +27,8 @@ class ItemOperator(Operator):
            raise ValueError("ItemOperator can only produce ScalarSpec outputs")

        # Create a tensor spec that can produce a scalar via .item()
        # Use a 1-D tensor with 1 element
        tensor_spec = TensorSpec(size=(1,), stride=(1,), dtype=output_spec.dtype)
        # Use a 0-D tensor (scalar tensor) to ensure .item() works reliably
        tensor_spec = TensorSpec(size=(), stride=(), dtype=output_spec.dtype)

        return [tensor_spec]
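The switch from a `(1,)` input spec to a 0-d spec mirrors how `.item()` behaves: it succeeds for any single-element tensor, but the 0-d case is the canonical one and keeps the graph's shape bookkeeping trivial. A quick illustration:

```python
import torch

x0 = torch.tensor(3.5)        # 0-d tensor, size ()
x1 = torch.full((1,), 3.5)    # 1-d tensor with one element
print(x0.item(), x1.item())   # both return the Python float 3.5

# .item() raises for anything with more than one element:
# torch.randn(2).item() -> RuntimeError
```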
@@ -0,0 +1,65 @@
"""Masked select operator implementation."""

from typing import Optional

import torch

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec


class MaskedSelectOperator(Operator):
    """Operator for selecting elements from a tensor based on a mask."""

    def __init__(self):
        super().__init__("masked_select")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.masked_select"

    def can_produce(self, output_spec: Spec) -> bool:
        """Masked select produces a 1D tensor with data-dependent size."""
        if not isinstance(output_spec, TensorSpec):
            return False

        # Output is always 1D with data-dependent size
        # Be very restrictive to avoid shape mismatches
        return (
            len(output_spec.size) == 1
            and output_spec.size[0] <= 10  # Reasonable size
            and output_spec.dtype not in [torch.bool]
        )  # Avoid bool outputs

    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Generate input specs for masked_select operation."""
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("MaskedSelectOperator can only produce TensorSpec outputs")

        # Input tensor - can be any shape and type
        input_tensor_spec = TensorSpec(
            size=(2, 3),  # Fixed size for consistency
            stride=(3, 1),  # Contiguous
            dtype=output_spec.dtype,  # Match output dtype
        )

        # Mask tensor - must be boolean and broadcastable to input
        mask_spec = TensorSpec(
            size=(2, 3),  # Same size as input for simplicity
            stride=(3, 1),  # Contiguous
            dtype=torch.bool,
        )

        return [input_tensor_spec, mask_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for masked_select operation."""
        if len(input_names) != 2:
            raise ValueError("MaskedSelectOperator requires exactly two inputs")

        return (
            f"{output_name} = torch.masked_select({input_names[0]}, {input_names[1]})"
        )
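masked_select is one of the data-dependent ops this template family targets: the output length equals the number of `True` entries in the mask, so it is unknown until runtime and becomes an unbacked symbolic size under `capture_dynamic_output_shape_ops`. For example:

```python
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
mask = x > 2.0
out = torch.masked_select(x, mask)
print(out)        # tensor([3., 4., 5.])
print(out.shape)  # torch.Size([3]) -- equals mask.sum(), only known at runtime
```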
@@ -0,0 +1,58 @@
"""Nonzero operator implementation."""

from typing import Optional

import torch

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec


class NonzeroOperator(Operator):
    """Operator for finding nonzero elements in a tensor."""

    def __init__(self):
        super().__init__("nonzero")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nonzero"

    def can_produce(self, output_spec: Spec) -> bool:
        """Nonzero produces a tensor with shape (n_nonzero, n_dims)."""
        if not isinstance(output_spec, TensorSpec):
            return False

        # Output shape is (n_nonzero, n_dims) where both are data-dependent
        # We can only produce integer tensors (indices) and only 2D tensors
        # Restrict to very specific shapes to avoid shape mismatches
        return (
            output_spec.dtype in [torch.int64, torch.long]
            and len(output_spec.size) == 2
            and output_spec.size[1] <= 4
        )  # Reasonable input dimensionality

    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
        """Generate input spec for nonzero operation."""
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("NonzeroOperator can only produce TensorSpec outputs")

        # Input can be any tensor type that supports comparison with zero
        # Use boolean tensors for simplicity to ensure some nonzero elements
        input_spec = TensorSpec(
            size=(3, 4),  # Fixed size that will have some nonzero elements
            stride=(4, 1),  # Contiguous
            dtype=torch.bool,  # Boolean tensors are good for nonzero testing
        )

        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for nonzero operation."""
        if len(input_names) != 1:
            raise ValueError("NonzeroOperator requires exactly one input")

        return f"{output_name} = torch.nonzero({input_names[0]})"
@@ -6,8 +6,21 @@ from torchfuzz.operators.arg import ArgOperator
from torchfuzz.operators.base import Operator
from torchfuzz.operators.constant import ConstantOperator
from torchfuzz.operators.item import ItemOperator
from torchfuzz.operators.scalar_pointwise import ScalarPointwiseOperator
from torchfuzz.operators.tensor_pointwise import TensorPointwiseOperator
from torchfuzz.operators.masked_select import MaskedSelectOperator
from torchfuzz.operators.nonzero import NonzeroOperator
from torchfuzz.operators.scalar_pointwise import (
    ScalarAddOperator,
    ScalarDivOperator,
    ScalarMulOperator,
    ScalarSubOperator,
)
from torchfuzz.operators.tensor_pointwise import (
    AddOperator,
    DivOperator,
    MulOperator,
    SubOperator,
)
from torchfuzz.operators.unique import UniqueOperator


class OperatorRegistry:

@@ -20,11 +33,25 @@ class OperatorRegistry:

    def _register_default_operators(self):
        """Register the default set of operators."""
        self.register(TensorPointwiseOperator())
        self.register(ScalarPointwiseOperator())
        # Individual tensor pointwise operators (preferred)
        self.register(AddOperator())
        self.register(MulOperator())
        self.register(SubOperator())
        self.register(DivOperator())

        # Individual scalar pointwise operators (preferred)
        self.register(ScalarAddOperator())
        self.register(ScalarMulOperator())
        self.register(ScalarSubOperator())
        self.register(ScalarDivOperator())

        self.register(ItemOperator())
        self.register(ConstantOperator())
        self.register(ArgOperator())
        # Data-dependent operators
        self.register(NonzeroOperator())
        self.register(MaskedSelectOperator())
        self.register(UniqueOperator())

    def register(self, operator: Operator):
        """Register an operator in the registry."""
@@ -1,17 +1,23 @@
"""Scalar pointwise operator implementation."""

import random
from typing import Optional

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec


class ScalarPointwiseOperator(Operator):
    """Operator for pointwise operations on scalars (add, mul, sub, div)."""
    """Base class for scalar pointwise operations."""

    def __init__(self):
        super().__init__("scalar_pointwise")
        self.operations = ["+", "*", "-", "/"]
    def __init__(self, name: str, symbol: str):
        super().__init__(name)
        self.symbol = symbol

    @property
    def torch_op_name(self) -> Optional[str]:
        """Scalar operations don't have specific torch ops, they use Python operators."""
        return None

    def can_produce(self, output_spec: Spec) -> bool:
        """Scalar pointwise operations can only produce scalars."""

@@ -21,7 +27,7 @@ class ScalarPointwiseOperator(Operator):
        """Decompose scalar into input scalars for pointwise operation with type promotion."""
        if not isinstance(output_spec, ScalarSpec):
            raise ValueError(
                "ScalarPointwiseOperator can only produce ScalarSpec outputs"
                f"{self.__class__.__name__} can only produce ScalarSpec outputs"
            )

        # Use shared type promotion utility

@@ -37,8 +43,34 @@ class ScalarPointwiseOperator(Operator):
    ) -> str:
        """Generate code for scalar pointwise operation."""
        if len(input_names) != 2:
            raise ValueError("ScalarPointwiseOperator requires exactly two inputs")
            raise ValueError(f"{self.__class__.__name__} requires exactly two inputs")

        # Randomly choose an operation
        op = random.choice(self.operations)
        return f"{output_name} = {input_names[0]} {op} {input_names[1]}"
        return f"{output_name} = {input_names[0]} {self.symbol} {input_names[1]}"


class ScalarAddOperator(ScalarPointwiseOperator):
    """Operator for scalar addition."""

    def __init__(self):
        super().__init__("scalar_add", "+")


class ScalarMulOperator(ScalarPointwiseOperator):
    """Operator for scalar multiplication."""

    def __init__(self):
        super().__init__("scalar_mul", "*")


class ScalarSubOperator(ScalarPointwiseOperator):
    """Operator for scalar subtraction."""

    def __init__(self):
        super().__init__("scalar_sub", "-")


class ScalarDivOperator(ScalarPointwiseOperator):
    """Operator for scalar division."""

    def __init__(self):
        super().__init__("scalar_div", "/")
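Splitting the monolithic random-choice operator into one subclass per symbol makes code generation deterministic for a given graph: the emitted line depends only on which operator node was chosen, not on a second `random.choice` inside `codegen`. A quick sketch of the resulting behaviour (the `ScalarSpec(dtype=torch.float32)` value is just a placeholder; `codegen` does not inspect it here, and the exact ScalarSpec fields are assumed from its usage in this diff):

```python
import torch
from torchfuzz.operators.scalar_pointwise import ScalarAddOperator, ScalarDivOperator
from torchfuzz.tensor_fuzzer import ScalarSpec

spec = ScalarSpec(dtype=torch.float32)  # placeholder output spec
print(ScalarAddOperator().codegen("out", ["a", "b"], spec))  # out = a + b
print(ScalarDivOperator().codegen("out", ["a", "b"], spec))  # out = a / b
```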
@@ -1,6 +1,7 @@
"""Tensor pointwise operator implementation."""

import random
from typing import Optional

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec

@@ -11,29 +12,18 @@ from torchfuzz.type_promotion import (
)


class TensorPointwiseOperator(Operator):
    """Operator for element-wise pointwise operations (add, mul, sub, div)."""
class PointwiseOperator(Operator):
    """Base class for element-wise pointwise operations."""

    def __init__(self):
        super().__init__("tensor_pointwise")
        self.operations = {
            "add": {
                "torch_op": "torch.ops.aten.add",
                "symbol": "+",
            },
            "mul": {
                "torch_op": "torch.ops.aten.mul",
                "symbol": "*",
            },
            "sub": {
                "torch_op": "torch.ops.aten.sub",
                "symbol": "-",
            },
            "div": {
                "torch_op": "torch.ops.aten.div",
                "symbol": "/",
            },
        }
    def __init__(self, name: str, torch_op: str, symbol: str):
        super().__init__(name)
        self._torch_op = torch_op
        self.symbol = symbol

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return self._torch_op

    def can_produce(self, output_spec: Spec) -> bool:
        """Tensor pointwise operations can produce tensors but not scalars."""

@@ -43,7 +33,7 @@ class TensorPointwiseOperator(Operator):
        """Decompose tensor into input tensors for pointwise operation with type promotion."""
        if not isinstance(output_spec, TensorSpec):
            raise ValueError(
                "TensorPointwiseOperator can only produce TensorSpec outputs"
                f"{self.__class__.__name__} can only produce TensorSpec outputs"
            )

        # Use shared type promotion table

@@ -79,13 +69,39 @@ class TensorPointwiseOperator(Operator):
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for pointwise operation."""
        # Randomly choose an operation
        op_name = random.choice(list(self.operations.keys()))
        op_info = self.operations[op_name]

        if len(input_names) == 2:
            return f"{output_name} = {op_info['torch_op']}({input_names[0]}, {input_names[1]})"
            return (
                f"{output_name} = {self._torch_op}({input_names[0]}, {input_names[1]})"
            )
        else:
            # Chain operations using symbols for readability
            expr = f" {op_info['symbol']} ".join(input_names)
            expr = f" {self.symbol} ".join(input_names)
            return f"{output_name} = {expr}"


class AddOperator(PointwiseOperator):
    """Operator for element-wise addition."""

    def __init__(self):
        super().__init__("add", "torch.add", "+")


class MulOperator(PointwiseOperator):
    """Operator for element-wise multiplication."""

    def __init__(self):
        super().__init__("mul", "torch.mul", "*")


class SubOperator(PointwiseOperator):
    """Operator for element-wise subtraction."""

    def __init__(self):
        super().__init__("sub", "torch.sub", "-")


class DivOperator(PointwiseOperator):
    """Operator for element-wise division."""

    def __init__(self):
        super().__init__("div", "torch.div", "/")
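Note the naming shift on the tensor side: the per-op classes report `torch.add`, `torch.mul`, etc. as their `torch_op_name` (rather than the old `torch.ops.aten.*` strings), which is exactly what the templates list in `supported_ops`, so the template filter in ops_fuzzer can match them directly. A small sketch of the two codegen paths, the binary call versus the chained expression (output_spec is unused by `codegen` here, so `None` stands in):

```python
from torchfuzz.operators.tensor_pointwise import AddOperator

op = AddOperator()
print(op.torch_op_name)                          # torch.add
print(op.codegen("out", ["a", "b"], None))       # out = torch.add(a, b)
print(op.codegen("out", ["a", "b", "c"], None))  # out = a + b + c  (chained form)
```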
@@ -0,0 +1,56 @@
"""Unique operator implementation."""

from typing import Optional

import torch

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec


class UniqueOperator(Operator):
    """Operator for finding unique elements in a tensor."""

    def __init__(self):
        super().__init__("unique")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.unique"

    def can_produce(self, output_spec: Spec) -> bool:
        """Unique produces a 1D tensor with data-dependent size."""
        if not isinstance(output_spec, TensorSpec):
            return False

        # Output is always 1D with data-dependent size
        # Be very restrictive to avoid shape mismatches
        return (
            len(output_spec.size) == 1
            and output_spec.size[0] <= 10  # Reasonable size
            and output_spec.dtype not in [torch.bool]
        )  # Avoid bool outputs

    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
        """Generate input spec for unique operation."""
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("UniqueOperator can only produce TensorSpec outputs")

        # Input can be any tensor - unique will flatten and find unique values
        input_spec = TensorSpec(
            size=(2, 3),  # Fixed size for consistency
            stride=(3, 1),  # Contiguous
            dtype=output_spec.dtype,  # Match output dtype
        )

        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for unique operation."""
        if len(input_names) != 1:
            raise ValueError("UniqueOperator requires exactly one input")

        return f"{output_name} = torch.unique({input_names[0]})"
@@ -30,6 +30,64 @@ def _get_cached_operators():
    return _CACHED_OPERATORS


def _get_template_filtered_operators(template: str = "default"):
    """Get operators filtered by template's supported_ops."""
    # Instantiate template
    if template == "dtensor":
        from torchfuzz.codegen import DTensorFuzzTemplate

        fuzz_template = DTensorFuzzTemplate()
    elif template == "unbacked":
        from torchfuzz.codegen import UnbackedFuzzTemplate

        fuzz_template = UnbackedFuzzTemplate()
    else:
        from torchfuzz.codegen import DefaultFuzzTemplate

        fuzz_template = DefaultFuzzTemplate()

    all_operators = _get_cached_operators()

    # If no supported_ops specified, return all operators
    if not fuzz_template.supported_ops:
        return all_operators

    # Filter operators based on supported_ops
    filtered_ops = {}

    for op_name, operator in all_operators.items():
        # Always include operations that don't have a specific torch operation
        # (utility operations like arg, constant, item, scalar ops)
        torch_op = operator.torch_op_name
        if torch_op is None:
            # Set template on operators that support it
            if hasattr(operator, "set_template"):
                operator.set_template(template)  # type: ignore[attr-defined]
            filtered_ops[op_name] = operator
            continue

        # Check if the operator supports any of the template's operations
        should_include = False
        for supported_op in fuzz_template.supported_ops:
            # Direct torch operation matching
            if torch_op == supported_op:
                should_include = True
                break

            # Direct name matching as fallback
            if supported_op in op_name or op_name in supported_op:
                should_include = True
                break

        if should_include:
            # Set template on operators that support it
            if hasattr(operator, "set_template"):
                operator.set_template(template)  # type: ignore[attr-defined]
            filtered_ops[op_name] = operator

    return filtered_ops


@dataclass
class OperationNode:
    """

@@ -156,7 +214,7 @@ class OperationGraph:
        return "\n".join(lines)


def fuzz_spec() -> Spec:
def fuzz_spec(template: str = "default") -> Spec:
    """
    Generate a random Spec (either TensorSpec or ScalarSpec) using tensor fuzzing functions.

@@ -165,11 +223,14 @@ def fuzz_spec() -> Spec:
    - fuzz_tensor_size() for random tensor size
    - fuzz_valid_stride() for random valid strides

    Args:
        template: Template name to determine supported dtypes

    Returns:
        Spec: Either a TensorSpec (80% probability) or ScalarSpec (20% probability) with random properties
    """
    # Get random dtype
    dtype = fuzz_torch_tensor_type()
    # Get random dtype based on template
    dtype = fuzz_torch_tensor_type(template)

    # 20% probability of returning ScalarSpec
    if random.random() < 0.2:

@@ -182,7 +243,9 @@ def fuzz_spec(template: str = "default") -> Spec:
    return TensorSpec(size=size, stride=stride, dtype=dtype)


def fuzz_op(target_spec: Spec, depth, stack_size) -> tuple[str, list[Spec]]:
def fuzz_op(
    target_spec: Spec, depth, stack_size, template: str = "default"
) -> tuple[str, list[Spec]]:
    """
    Given an output specification, returns an operation that can
    produce a tensor with that layout using the operator class system.

@@ -197,8 +260,8 @@ def fuzz_op(target_spec: Spec, depth, stack_size) -> tuple[str, list[Spec]]:
        Tuple of (operation_name, list_of_argument_specs) where each argument spec
        describes the layout requirements for the operation's inputs
    """
    # Get all available operators (cached)
    available_operators = _get_cached_operators()
    # Get template-filtered operators
    available_operators = _get_template_filtered_operators(template)

    # Filter operators that can produce the target spec
    compatible_ops = []

@@ -274,6 +337,7 @@ def fuzz_operation_graph(
    target_spec: Spec,
    max_depth: int = 7,
    seed: Optional[int] = None,
    template: str = "default",
) -> OperationGraph:
    """
    Generate a graph of operations that produces the target specification.

@@ -285,6 +349,7 @@ def fuzz_operation_graph(
        target_spec: The desired output specification (TensorSpec or ScalarSpec)
        max_depth: Maximum depth of operations. At depth 0, only leaf operations (constant, arg) are used.
        seed: Random seed for reproducible generation. If None, uses current random state.
        template: Template name to determine configuration

    Returns:
        OperationGraph with nodes organized in a DAG structure

@@ -310,7 +375,7 @@ def fuzz_operation_graph(
        nonlocal node_counter

        # Generate new operation
        op_name, input_specs = fuzz_op(spec, depth, stack_size)
        op_name, input_specs = fuzz_op(spec, depth, stack_size, template)

        # Create unique node ID
        node_id = f"node_{node_counter}"
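The net effect of `_get_template_filtered_operators` is that graph fuzzing only ever draws from ops the chosen template knows how to exercise, while utility nodes (`arg`, `constant`, `item`, the scalar ops) are always kept because their `torch_op_name` is `None`. A minimal sketch of inspecting that set, assuming the graph-fuzzing module shown here is importable as `torchfuzz.ops_fuzzer`:

```python
from torchfuzz.ops_fuzzer import _get_template_filtered_operators

default_ops = _get_template_filtered_operators("default")
unbacked_ops = _get_template_filtered_operators("unbacked")

# The default template only lists torch.add/sub/mul/div, so data-dependent
# operators like nonzero or masked_select should only show up for "unbacked".
print(sorted(default_ops))
print(sorted(unbacked_ops))
```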
@@ -34,36 +34,35 @@ class ScalarSpec(NamedTuple):
Spec = Union[TensorSpec, ScalarSpec]


def fuzz_torch_tensor_type() -> torch.dtype:
def fuzz_torch_tensor_type(template: str = "default") -> torch.dtype:
    """
    Fuzzes PyTorch tensor data types by randomly selecting and returning different dtypes.

    Args:
        template: Template name to determine supported dtypes

    Returns:
        torch.dtype: A randomly selected PyTorch tensor data type
        torch.dtype: A randomly selected PyTorch tensor data type based on template constraints
    """

    # Available PyTorch tensor data types (excluding unsigned types to avoid compatibility issues)
    tensor_dtypes: list[torch.dtype] = [
        torch.float32,
        torch.float64,
        torch.float16,
        torch.bfloat16,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.bool,
        torch.complex64,
        torch.complex128,
    ]
    # Get template-specific dtypes
    if template == "dtensor":
        # Import here to avoid circular imports
        from torchfuzz.codegen import DTensorFuzzTemplate

    # Filter out complex dtypes if avoid_complex is enabled
    if FuzzerConfig.avoid_complex:
        tensor_dtypes = [
            dtype
            for dtype in tensor_dtypes
            if dtype not in [torch.complex64, torch.complex128]
        ]
        fuzz_template = DTensorFuzzTemplate()
        tensor_dtypes = fuzz_template.supported_dtypes()
    elif template == "unbacked":
        # Import here to avoid circular imports
        from torchfuzz.codegen import UnbackedFuzzTemplate

        fuzz_template = UnbackedFuzzTemplate()
        tensor_dtypes = fuzz_template.supported_dtypes()
    else:
        from torchfuzz.codegen import DefaultFuzzTemplate

        fuzz_template = DefaultFuzzTemplate()
        tensor_dtypes = fuzz_template.supported_dtypes()

    # Randomly select and return a data type
    return random.choice(tensor_dtypes)

@@ -367,7 +366,7 @@ def fuzz_tensor(
        size = fuzz_tensor_size()

    if dtype is None:
        dtype = fuzz_torch_tensor_type()
        dtype = fuzz_torch_tensor_type("default")

    if stride is None:
        stride = fuzz_valid_stride(size)

@@ -445,7 +444,7 @@ def fuzz_non_contiguous_dense_tensor(
        torch.Tensor: A non-contiguous but dense tensor
    """
    if dtype is None:
        dtype = fuzz_torch_tensor_type()
        dtype = fuzz_torch_tensor_type("default")

    if size is None:
        size = fuzz_tensor_size()