# File: pytorch/tools/experimental/torchfuzz/codegen.py
# mypy: ignore-errors
import os
from typing import Optional

import torch

from torchfuzz.operators import get_operator
from torchfuzz.ops_fuzzer import OperationGraph
from torchfuzz.tensor_descriptor import format_tensor_descriptor
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec


class FuzzTemplate:
    def __init__(self, supported_ops, check):
        self.supported_ops = supported_ops
        self.check = check

    def supported_dtypes(self):
        """Return list of supported dtypes for this template."""
        return [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.bfloat16,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.bool,
        ]

    def spec_distribution(self):
        """
        Define the distribution for generating random Specs.

        Returns:
            Dict with keys:
            - 'tensor_prob': Probability of generating TensorSpec (0.0 to 1.0)
            - 'scalar_prob': Probability of generating ScalarSpec (0.0 to 1.0)
            - 'allow_tensors': Whether TensorSpec generation is allowed (boolean)
            - 'allow_scalars': Whether ScalarSpec generation is allowed (boolean)
        """
        return {
            "tensor_prob": 0.8,
            "scalar_prob": 0.2,
            "allow_tensors": True,
            "allow_scalars": True,
        }
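
    # Note: subclasses override spec_distribution(); e.g. DefaultFuzzTemplate
    # below returns tensor_prob=1.0 with allow_scalars=False (tensors only).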

    def fuzz_spec_custom(self):
        """
        Generate a random Spec based on this template's distribution preferences.

        Returns:
            Spec: Either a TensorSpec or ScalarSpec according to the template's
            distribution.
        """
        import random

        from torchfuzz.tensor_fuzzer import fuzz_torch_tensor_type

        # Get template's distribution configuration
        distribution = self.spec_distribution()
        # Get random dtype based on template
        dtype = fuzz_torch_tensor_type("default")
        # Validate distribution configuration
        allow_tensors = distribution.get("allow_tensors", True)
        allow_scalars = distribution.get("allow_scalars", True)
        if not allow_tensors and not allow_scalars:
            raise ValueError("Template must allow at least one of tensors or scalars")
        # Determine which type to generate
        if not allow_scalars:
            # Only tensors allowed
            return self._generate_tensor_spec(dtype)
        elif not allow_tensors:
            # Only scalars allowed
            return self._generate_scalar_spec(dtype)
        else:
            # Both allowed, use probability distribution
            tensor_prob = distribution.get("tensor_prob", 0.8)
            if random.random() < tensor_prob:
                return self._generate_tensor_spec(dtype)
            else:
                return self._generate_scalar_spec(dtype)

    def _generate_tensor_spec(self, dtype):
        """Generate a TensorSpec with the given dtype."""
        from torchfuzz.tensor_fuzzer import (
            fuzz_tensor_size,
            fuzz_valid_stride,
            TensorSpec,
        )

        size = fuzz_tensor_size()
        stride = fuzz_valid_stride(size)
        return TensorSpec(size=size, stride=stride, dtype=dtype)

    def _generate_scalar_spec(self, dtype):
        """Generate a ScalarSpec with the given dtype."""
        from torchfuzz.tensor_fuzzer import ScalarSpec

        return ScalarSpec(dtype=dtype)

    def args_codegen(self, arg_operations):
        """Generate argument creation code for the default template."""
        code_lines = []
        # Add sentinel tensor that ensures gradient computation
        code_lines.extend(
            [
                "# Sentinel tensor to ensure gradient computation",
                "sentinel = torch.tensor(1.0, requires_grad=True)",
                "",
            ]
        )
        if arg_operations:
            for i, (node_id, spec) in enumerate(arg_operations):
                arg_name = f"arg_{i}"
                if isinstance(spec, ScalarSpec):
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    if spec.dtype in [
                        torch.int8,
                        torch.int16,
                        torch.int32,
                        torch.int64,
                    ]:
                        # For integer scalars, use randint to avoid always getting 0
                        code_lines.append(
                            f"{arg_name} = int(torch.randint(5, 30, ()).item())"
                        )
                    elif spec.dtype == torch.bool:
                        # For boolean scalars, use randint and cast to bool
                        code_lines.append(
                            f"{arg_name} = bool(torch.randint(0, 2, ()).item())"
                        )
                    else:
                        # For float scalars, use randn
                        code_lines.append(
                            f"{arg_name} = float(torch.randn((), dtype={dtype_str}).item())"
                        )
                elif isinstance(spec, TensorSpec):
                    size_str = str(spec.size)
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    # Calculate storage size needed for the strided tensor
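                    # Worked example: size=(2, 3), stride=(3, 1) gives
                    # max_offset = (2 - 1) * 3 + (3 - 1) * 1 = 5, so 6 storage
                    # elements are needed to back the strided view.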
                    if spec.size:
                        # Calculate the maximum index that will be accessed
                        max_offset = 0
                        for dim_size, stride in zip(spec.size, spec.stride):
                            if dim_size > 1:
                                max_offset += (dim_size - 1) * abs(stride)
                        storage_size = max_offset + 1
                    else:
                        storage_size = 1
                    stride_str = str(spec.stride)
                    # Special handling for integer tensors, which might be used as indices
                    if spec.dtype in [
                        torch.int8,
                        torch.int16,
                        torch.int32,
                        torch.int64,
                    ]:
                        # For integer tensors, generate valid indices with headroom for
                        # arithmetic. The small range [5, 30) leaves room for multiplication
                        # and other operations, preventing indices from growing too large.
                        min_val = 5  # Minimum to avoid negative results after subtraction
                        max_val = 30  # Maximum to avoid out-of-bounds after multiplication
                        code_lines.append(
                            f"{arg_name} = torch.as_strided(torch.randint({min_val}, {max_val}, ({storage_size},)).to({dtype_str}), {size_str}, {stride_str})"
                        )
                    elif spec.dtype == torch.bool:
                        # For boolean tensors, use randint to generate True/False values.
                        # Using randn().to(bool) would yield almost all True due to non-zero floats.
                        code_lines.append(
                            f"{arg_name} = torch.as_strided(torch.randint(0, 2, ({storage_size},), dtype=torch.int8).bool(), {size_str}, {stride_str})"
                        )
                    else:
                        code_lines.append(
                            f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"
                        )
        return code_lines
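

# For illustration: for a float32 TensorSpec with size=(2, 3) and stride=(3, 1),
# FuzzTemplate.args_codegen above emits a line shaped like
#   arg_0 = torch.as_strided(torch.randn(6).to(torch.float32), (2, 3), (3, 1))
# (illustrative values, not taken from a real run).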


class DefaultFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck

        super().__init__(
            supported_ops=[
                # Basic arithmetic operations
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
                # Tensor shape operations
                "torch.Tensor.view",
                "torch.reshape",
                "torch.flatten",
                "torch.squeeze",
                "torch.unsqueeze",
                # Matrix operations
                "torch.mm",
                "torch.addmm",
                "torch.bmm",
                "torch.matmul",
                # Neural network operations
                "torch.nn.functional.embedding",
                "torch.nn.functional.linear",
                # Activation functions
                "torch.nn.functional.relu",
                "torch.nn.functional.leaky_relu",
                "torch.nn.functional.elu",
                "torch.nn.functional.gelu",
                "torch.nn.functional.silu",
                "torch.sigmoid",
                "torch.tanh",
                "torch.nn.functional.softmax",
                # Normalization layers
                "torch.nn.functional.layer_norm",
                "torch.nn.functional.rms_norm",
                "torch.nn.functional.batch_norm",
                "torch.nn.functional.group_norm",
                # Regularization
                "torch.nn.functional.dropout",
            ],
            check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
        )

    def spec_distribution(self):
        """Default template: tensor-only (no scalars)."""
        return {
            "tensor_prob": 1.0,
            "scalar_prob": 0.0,
            "allow_tensors": True,
            "allow_scalars": False,
        }

    def imports_codegen(self):
        return [
            "import torch",
        ]

    def flags_codegen(self):
        return ["torch._dynamo.config.capture_scalar_outputs = True"]

    def epilogue_codegen(self):
        return []


class DTensorFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck

        super().__init__(
            supported_ops=[
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
                "torch.mm",
                "torch.addmm",
                "torch.bmm",
                "torch.matmul",
            ],
            check=EagerVsFullGraphDynamicCompileCheck(),
        )

    def supported_dtypes(self):
        """Return list of DTensor-compatible dtypes (no complex types)."""
        return [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.bfloat16,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.bool,
        ]

    def spec_distribution(self):
        """DTensor template: tensor-only (no scalars)."""
        return {
            "tensor_prob": 1.0,
            "scalar_prob": 0.0,
            "allow_tensors": True,
            "allow_scalars": False,
        }

    def imports_codegen(self):
        return [
            "import torch",
            "from torch.distributed.tensor.placement_types import Replicate, Shard",
            "from torch.testing._internal.distributed.fake_pg import FakeStore",
            "from torch.distributed.tensor import DTensor",
        ]

    def flags_codegen(self):
        return [
            "torch._dynamo.config.capture_scalar_outputs = True",
            "torch._dynamo.config.capture_dynamic_output_shape_ops = True",
            "torch._inductor.config.emulate_precision_casts = True",
        ]

    def args_codegen(self, arg_operations):
        """Generate DTensor argument creation code with proper mesh setup."""
        code_lines = []
        # Add DTensor setup code first
        code_lines.extend(
            [
                "world_size = 1024",
                "fake_store = FakeStore()",
                "torch.distributed.init_process_group(",
                '    "fake", store=fake_store, rank=0, world_size=world_size',
                ")",
                "",
                "mesh = torch.distributed.device_mesh.init_device_mesh(",
                '    "cuda",',
                "    (2, 8),",
                "    mesh_dim_names=(",
                '        "dim1", "dim2",',
                "    ),",
                ")",
                "",
                "placements = (Replicate(), Replicate())",
                "",
                "# Sentinel tensor to ensure gradient computation",
                "sentinel_local = torch.tensor(1.0, device='cuda', requires_grad=True)",
                "sentinel = DTensor.from_local(sentinel_local, mesh, placements)",
                "",
            ]
        )
        if arg_operations:
            for i, (node_id, spec) in enumerate(arg_operations):
                arg_name = f"arg_{i}"
                if isinstance(spec, ScalarSpec):
                    # For scalars in DTensor, create a 0-dim tensor
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    code_lines.extend(
                        [
                            f"{arg_name}_local = torch.randn((), dtype={dtype_str}, device='cuda', requires_grad=True)",
                            f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                        ]
                    )
                elif isinstance(spec, TensorSpec):
                    size_str = str(spec.size)
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    # Handle different dtypes appropriately for DTensor
                    if spec.dtype in [
                        torch.int32,
                        torch.int64,
                        torch.int8,
                        torch.int16,
                    ]:
                        # Integer dtypes: use randint and no requires_grad
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randint(1, 10, {size_str}, dtype={dtype_str}, device='cuda')",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )
                    elif spec.dtype == torch.bool:
                        # Boolean dtype: use randint and cast to bool
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randint(0, 2, {size_str}, device='cuda').bool()",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )
                    else:
                        # Float dtypes: use randn and requires_grad
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randn({size_str}, dtype={dtype_str}, device='cuda', requires_grad=True)",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )
        return code_lines

    def epilogue_codegen(self):
        return ["torch.distributed.destroy_process_group()"]
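

# For illustration: for a float32 TensorSpec of size (4, 4), the DTensor
# args_codegen above emits a pair of lines shaped like
#   arg_0_local = torch.randn((4, 4), dtype=torch.float32, device='cuda', requires_grad=True)
#   arg_0 = DTensor.from_local(arg_0_local, mesh, placements)
# (illustrative values, not taken from a real run).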


class UnbackedFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck

        super().__init__(
            supported_ops=[
                "torch.ops.aten.item",
                "torch.ops.aten.nonzero",
                "torch.ops.aten.masked_select",
                "torch.ops.aten.unique",
                # Basic arithmetic operations
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
                # Tensor shape operations
                "torch.Tensor.view",
                "torch.reshape",
                "torch.flatten",
                "torch.squeeze",
                "torch.unsqueeze",
                # Matrix operations
                "torch.mm",
                "torch.addmm",
                "torch.bmm",
                "torch.matmul",
                # Neural network operations
                "torch.nn.functional.embedding",
                "torch.nn.functional.linear",
                # Activation functions
                "torch.nn.functional.relu",
                "torch.nn.functional.leaky_relu",
                "torch.nn.functional.elu",
                "torch.nn.functional.gelu",
                "torch.nn.functional.silu",
                "torch.sigmoid",
                "torch.tanh",
                "torch.nn.functional.softmax",
                # Normalization layers
                "torch.nn.functional.layer_norm",
                "torch.nn.functional.rms_norm",
                "torch.nn.functional.batch_norm",
                "torch.nn.functional.group_norm",
                # Regularization
                "torch.nn.functional.dropout",
            ],
            check=EagerVsFullGraphDynamicCompileCheck(),
        )

    def supported_dtypes(self):
        """Return list of dtypes suited to data-dependent operations."""
        # Focus on dtypes that work well with data-dependent ops and arithmetic.
        # Exclude bool, since arithmetic operations don't work on boolean tensors.
        return [
            torch.float32,
            torch.float64,
            torch.int32,
            torch.int64,
        ]

    def spec_distribution(self):
        """Unbacked template: 50% tensors, 50% scalars."""
        return {
            "tensor_prob": 0.5,
            "scalar_prob": 0.5,
            "allow_tensors": True,
            "allow_scalars": True,
        }

    def imports_codegen(self):
        return [
            "import torch",
        ]

    def flags_codegen(self):
        return [
            "torch._dynamo.config.capture_scalar_outputs = True",
            "torch._dynamo.config.capture_dynamic_output_shape_ops = True",
        ]

    def epilogue_codegen(self):
        return []


def convert_graph_to_python_code(
    operation_graph: OperationGraph,
    seed: Optional[int] = None,
    template: str = "default",
) -> str:
    """
    Convert an operation graph to executable Python code using topological ordering.

    The graph-based approach generates code by:
    1. Getting the topological order of nodes (dependencies before dependents)
    2. Generating code for each node in that order
    3. Properly handling input dependencies through node connections

    Args:
        operation_graph: OperationGraph instance containing the operation DAG
        seed: Random seed for reproducible code generation. If None, uses the
            current random state.
        template: Which code-generation template to use: "default", "dtensor",
            or "unbacked".

    Returns:
        String containing the complete Python code that executes the operations
    """
    # Instantiate template
    if template == "dtensor":
        fuzz_template = DTensorFuzzTemplate()
    elif template == "unbacked":
        fuzz_template = UnbackedFuzzTemplate()
    else:
        fuzz_template = DefaultFuzzTemplate()

    # Set seed for reproducible code generation
    if seed is not None:
        import random

        random.seed(seed + 1000)  # Offset to avoid conflicts with graph generation
        torch.manual_seed(seed + 1000)

    if not operation_graph.nodes:
        raise ValueError("Empty operation graph")

    # Get topological order - this ensures dependencies are processed before dependents
    topo_order = operation_graph.get_topological_order()

    # Track generated variables and arg operations
    generated_code_lines = []
    node_variables: dict[str, tuple[str, Spec]] = {}  # Maps node_id to (var_name, spec)
    arg_operations: list[
        tuple[str, Spec]
    ] = []  # List of (node_id, spec) for arg operations

    # Process nodes in topological order
    for node_id in topo_order:
        node = operation_graph.nodes[node_id]
        op_name = node.op_name
        output_spec = node.output_spec

        # Generate output variable name
        output_var_name = f"var_{node_id}"

        # Generate input variable names from input nodes
        input_var_names = []
        for input_node_id in node.input_nodes:
            if input_node_id in node_variables:
                input_var_name, _ = node_variables[input_node_id]
                input_var_names.append(input_var_name)
            else:
                raise ValueError(
                    f"Node {node_id} depends on {input_node_id}, but {input_node_id} "
                    f"was not processed yet. Topological order may be incorrect."
                )

        # Handle different operation types
        if op_name == "arg" or op_name.startswith("arg_"):
            # Track arg operations for later function signature generation
            arg_operations.append((node_id, output_spec))
            arg_name = f"arg_{len(arg_operations) - 1}"
            # Add tensor descriptor comment for arg operations too
            descriptor_comment = f"# {format_tensor_descriptor(output_spec)}"
            operation_lines = [f"{output_var_name} = {arg_name} " + descriptor_comment]
        else:
            # Generate operation execution code
            operation_lines = generate_simple_operation_code(
                output_var_name, input_var_names, op_name, output_spec
            )

        # Add proper indentation for function body
        generated_code_lines.extend(["    " + line for line in operation_lines])

        # Track this node's variable
        node_variables[node_id] = (output_var_name, output_spec)

    # The final result comes from the root node
    root_node_id = operation_graph.root_node_id
    if root_node_id not in node_variables:
        raise ValueError(f"Root node {root_node_id} was not processed")
    final_var_name, _ = node_variables[root_node_id]

    # Generate function signature based on discovered arg operations
    if arg_operations:
        arg_names = [f"arg_{i}" for i in range(len(arg_operations))]
        function_signature = f"def fuzzed_program({', '.join(arg_names)}, sentinel)"
    else:
        function_signature = "def fuzzed_program(sentinel)"

    # Build the complete code - all imports at the top
    code_lines = []

    # Add template imports
    code_lines.extend(fuzz_template.imports_codegen())
    # Add template flags
    code_lines.extend(fuzz_template.flags_codegen())
    code_lines.append("")

    # Add a single seed call at the top if a seed is provided
    if seed is not None:
        code_lines.append(f"torch.manual_seed({seed})")
        code_lines.append("")

    code_lines.append(function_signature + ":")

    # Add the generated operation code
    code_lines.extend(generated_code_lines)

    # Add return statement with sentinel multiplication to ensure gradient
    # computation. Handle complex tensors appropriately based on template.
    if template == "dtensor":
        # For DTensor, avoid .real, which doesn't work with sharding;
        # use abs() on complex tensors to get a real result instead.
        code_lines.extend(
            [
                "    # Ensure gradient computation by multiplying with sentinel",
                f"    result = {final_var_name} * sentinel",
                "    if result.is_complex():",
                "        result = result.abs()  # Use abs() instead of .real for DTensor compatibility",
                "    return result",
                "",
            ]
        )
    else:
        code_lines.extend(
            [
                "    # Ensure gradient computation by multiplying with sentinel and taking real part",
                f"    result = {final_var_name} * sentinel",
                "    if result.is_complex():",
                "        result = result.real",
                "    return result",
                "",
            ]
        )

    # Generate argument creation code using template
    arg_code_lines = fuzz_template.args_codegen(arg_operations)
    code_lines.extend(arg_code_lines)

    # Generate the final execution with both normal and compiled versions
    if arg_operations:
        arg_names = [f"arg_{i}" for i in range(len(arg_operations))]
        if len(arg_names) == 1:
            args_tuple = (
                f"({arg_names[0]},)"  # A single-element tuple needs a trailing comma
            )
        else:
            args_tuple = f"({', '.join(arg_names)})"
    else:
        args_tuple = "()"

    # Generate execution code using the template's check
    check_lines = fuzz_template.check.codegen(f"{args_tuple} + (sentinel,)")
    code_lines.extend([""] + check_lines)

    # Add template epilogue
    epilogue_lines = fuzz_template.epilogue_codegen()
    if epilogue_lines:
        code_lines.append("")
        code_lines.extend(epilogue_lines)

    return "\n".join(code_lines)
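

# For illustration, a generated program has roughly this overall shape (exact
# lines depend on the template, the operator registry, and the check; example
# only, not taken from a real run):
#
#   import torch
#   torch._dynamo.config.capture_scalar_outputs = True
#
#   torch.manual_seed(42)
#
#   def fuzzed_program(arg_0, sentinel):
#       var_1 = arg_0  # <tensor descriptor>
#       var_0 = torch.nn.functional.relu(var_1)  # <tensor descriptor>
#       # Ensure gradient computation by multiplying with sentinel and taking real part
#       result = var_0 * sentinel
#       if result.is_complex():
#           result = result.real
#       return result
#
#   # Sentinel tensor to ensure gradient computation
#   sentinel = torch.tensor(1.0, requires_grad=True)
#   arg_0 = torch.as_strided(torch.randn(6).to(torch.float32), (2, 3), (3, 1))
#   <check harness emitted by fuzz_template.check.codegen(...), then epilogue>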


def generate_simple_operation_code(
    output_var: str,
    input_vars: list,
    op_name: str,
    output_spec,
) -> list:
    """
    Generate code lines for executing a single operation using class-based operators.

    Args:
        output_var: Name of the output variable
        input_vars: List of input variable names
        op_name: Name of the operation
        output_spec: Output specification for the operation

    Returns:
        List of code lines: a single line for registered operators, or a
        placeholder comment for unknown operations.
    """
    # Try to get the operator from the registry
    operator = get_operator(op_name)
    if operator is not None:
        # Use the class-based operator to generate code
        code_line = operator.codegen(output_var, input_vars, output_spec)
        # Add tensor descriptor comment
        descriptor_comment = f"# {format_tensor_descriptor(output_spec)}"
        return [code_line + " " + descriptor_comment]
    else:
        # Fallback for unknown operations
        return [f"# Unknown operation: {op_name}"]
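

# For illustration: for op_name="torch.add" with input_vars=["var_1", "var_2"],
# the registered operator might emit something shaped like
#   var_0 = torch.add(var_1, var_2)  # <tensor descriptor>
# where the trailing comment comes from format_tensor_descriptor (the exact
# line depends on the operator implementation in the registry).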


def create_program_file(python_code: str) -> str:
    """
    Create a temporary Python file from the generated code.

    Args:
        python_code: String containing Python code to write

    Returns:
        Path to the created temporary file
    """
    import hashlib

    # Generate a deterministic filename based on a hash of the code content
    code_hash = hashlib.md5(python_code.encode()).hexdigest()[:8]  # noqa: S324
    tmp_dir = "/tmp/torchfuzz"
    os.makedirs(tmp_dir, exist_ok=True)
    generated_file_path = os.path.join(tmp_dir, f"fuzz_{code_hash}.py")

    # Write the generated code to the specified file
    with open(generated_file_path, "w") as f:
        f.write(python_code)

    return generated_file_path
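

# Minimal usage sketch (assumption: `graph` is an OperationGraph produced by
# the torchfuzz.ops_fuzzer APIs, which are not shown in this file):
#
#   code = convert_graph_to_python_code(graph, seed=42, template="default")
#   path = create_program_file(code)
#   # `path` points at /tmp/torchfuzz/fuzz_<hash>.py, ready to be run in a
#   # subprocess so the template's check can compare eager vs. compiled runs.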