# mypy: ignore-errors

import os

from typing import Optional

import torch

from torchfuzz.operators import get_operator
from torchfuzz.ops_fuzzer import OperationGraph
from torchfuzz.tensor_descriptor import format_tensor_descriptor
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec


class FuzzTemplate:
    def __init__(self, supported_ops, check):
        self.supported_ops = supported_ops
        self.check = check

    def supported_dtypes(self):
        """Return list of supported dtypes for this template."""
        return [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.bfloat16,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.bool,
        ]

    def spec_distribution(self):
        """
        Define the distribution for generating random Specs.

        Returns:
            Dict with keys:
            - 'tensor_prob': Probability of generating TensorSpec (0.0 to 1.0)
            - 'scalar_prob': Probability of generating ScalarSpec (0.0 to 1.0)
            - 'allow_tensors': Whether TensorSpec generation is allowed (boolean)
            - 'allow_scalars': Whether ScalarSpec generation is allowed (boolean)
        """
        return {
            "tensor_prob": 0.8,
            "scalar_prob": 0.2,
            "allow_tensors": True,
            "allow_scalars": True,
        }

    def fuzz_spec_custom(self):
        """
        Generate a random Spec based on this template's distribution preferences.

        Returns:
            Spec: Either a TensorSpec or a ScalarSpec, chosen according to
            the template's distribution.
        """
        import random

        from torchfuzz.tensor_fuzzer import fuzz_torch_tensor_type

        # Get this template's distribution configuration
        distribution = self.spec_distribution()

        # Pick a random dtype (uses the default dtype distribution)
        dtype = fuzz_torch_tensor_type("default")

        # Validate the distribution configuration
        allow_tensors = distribution.get("allow_tensors", True)
        allow_scalars = distribution.get("allow_scalars", True)

        if not allow_tensors and not allow_scalars:
            raise ValueError("Template must allow at least one of tensors or scalars")

        # Determine which kind of spec to generate
        if not allow_scalars:
            # Only tensors allowed
            return self._generate_tensor_spec(dtype)
        elif not allow_tensors:
            # Only scalars allowed
            return self._generate_scalar_spec(dtype)
        else:
            # Both allowed: sample according to the probability distribution
            tensor_prob = distribution.get("tensor_prob", 0.8)
            if random.random() < tensor_prob:
                return self._generate_tensor_spec(dtype)
            else:
                return self._generate_scalar_spec(dtype)

    def _generate_tensor_spec(self, dtype):
        """Generate a TensorSpec with the given dtype."""
        from torchfuzz.tensor_fuzzer import (
            fuzz_tensor_size,
            fuzz_valid_stride,
            TensorSpec,
        )

        size = fuzz_tensor_size()
        stride = fuzz_valid_stride(size)
        return TensorSpec(size=size, stride=stride, dtype=dtype)

    def _generate_scalar_spec(self, dtype):
        """Generate a ScalarSpec with the given dtype."""
        from torchfuzz.tensor_fuzzer import ScalarSpec

        return ScalarSpec(dtype=dtype)
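
    # Illustrative usage (a sketch, not executed as part of this module):
    #   template = DefaultFuzzTemplate()
    #   spec = template.fuzz_spec_custom()
    # For DefaultFuzzTemplate this always yields a TensorSpec, since its
    # spec_distribution() disallows scalars.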

    def args_codegen(self, arg_operations):
        """Generate argument creation code for default template."""
        code_lines = []

        # Add sentinel tensor that ensures gradient computation
        code_lines.extend(
            [
                "# Sentinel tensor to ensure gradient computation",
                "sentinel = torch.tensor(1.0, requires_grad=True)",
                "",
            ]
        )

        if arg_operations:
            for i, (node_id, spec) in enumerate(arg_operations):
                arg_name = f"arg_{i}"

                if isinstance(spec, ScalarSpec):
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    if spec.dtype in [
                        torch.int8,
                        torch.int16,
                        torch.int32,
                        torch.int64,
                    ]:
                        # For integer scalars, use randint to avoid always getting 0
                        code_lines.append(
                            f"{arg_name} = int(torch.randint(5, 30, ()).item())"
                        )
                    elif spec.dtype == torch.bool:
                        # For boolean scalars, use randint and cast to bool
                        code_lines.append(
                            f"{arg_name} = bool(torch.randint(0, 2, ()).item())"
                        )
                    else:
                        # For float scalars, use randn
                        code_lines.append(
                            f"{arg_name} = float(torch.randn((), dtype={dtype_str}).item())"
                        )

                elif isinstance(spec, TensorSpec):
                    size_str = str(spec.size)
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")

                    # Calculate storage size needed for the strided tensor
                    if spec.size:
                        # Calculate the maximum index that will be accessed
                        max_offset = 0
                        for dim_size, stride in zip(spec.size, spec.stride):
                            if dim_size > 1:
                                max_offset += (dim_size - 1) * abs(stride)
                        storage_size = max_offset + 1
                    else:
                        storage_size = 1
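
                    # Worked example (illustrative): for size (2, 3) with
                    # stride (3, 1), max_offset = (2-1)*3 + (3-1)*1 = 5, so
                    # storage_size = 6 elements is enough to back every index
                    # that torch.as_strided can touch.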

                    stride_str = str(spec.stride)

                    # Special handling for integer tensors, which might be used as indices
                    if spec.dtype in [
                        torch.int8,
                        torch.int16,
                        torch.int32,
                        torch.int64,
                    ]:
                        # For integer tensors, generate valid indices with headroom for
                        # arithmetic. The smaller range [5, 30) allows for multiplication
                        # and other operations without indices becoming too large.
                        min_val = 5  # Minimum to avoid negative results after subtraction
                        max_val = 30  # Maximum to avoid out-of-bounds after multiplication
                        code_lines.append(
                            f"{arg_name} = torch.as_strided(torch.randint({min_val}, {max_val}, ({storage_size},)).to({dtype_str}), {size_str}, {stride_str})"
                        )
                    elif spec.dtype == torch.bool:
                        # For boolean tensors, use randint to generate True/False values;
                        # randn().to(bool) would yield almost all True, since non-zero
                        # floats are truthy
                        code_lines.append(
                            f"{arg_name} = torch.as_strided(torch.randint(0, 2, ({storage_size},), dtype=torch.int8).bool(), {size_str}, {stride_str})"
                        )
                    else:
                        code_lines.append(
                            f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"
                        )

        return code_lines
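
    # Example of the generated lines (illustrative, assuming one float32
    # TensorSpec of size (2, 3) and stride (3, 1)):
    #   sentinel = torch.tensor(1.0, requires_grad=True)
    #   arg_0 = torch.as_strided(torch.randn(6).to(torch.float32), (2, 3), (3, 1))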


class DefaultFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck

        super().__init__(
            supported_ops=[
                # Basic arithmetic operations
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
                # Tensor shape operations
                "torch.Tensor.view",
                "torch.reshape",
                "torch.flatten",
                "torch.squeeze",
                "torch.unsqueeze",
                # Matrix operations
                "torch.mm",
                "torch.addmm",
                "torch.bmm",
                "torch.matmul",
                # Neural network operations
                "torch.nn.functional.embedding",
                "torch.nn.functional.linear",
                # Activation functions
                "torch.nn.functional.relu",
                "torch.nn.functional.leaky_relu",
                "torch.nn.functional.elu",
                "torch.nn.functional.gelu",
                "torch.nn.functional.silu",
                "torch.sigmoid",
                "torch.tanh",
                "torch.nn.functional.softmax",
                # Normalization layers
                "torch.nn.functional.layer_norm",
                "torch.nn.functional.rms_norm",
                "torch.nn.functional.batch_norm",
                "torch.nn.functional.group_norm",
                # Regularization
                "torch.nn.functional.dropout",
            ],
            check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
        )

    def spec_distribution(self):
        """Default template: tensor-only (no scalars)."""
        return {
            "tensor_prob": 1.0,
            "scalar_prob": 0.0,
            "allow_tensors": True,
            "allow_scalars": False,
        }

    def imports_codegen(self):
        return [
            "import torch",
        ]

    def flags_codegen(self):
        return ["torch._dynamo.config.capture_scalar_outputs = True"]

    def epilogue_codegen(self):
        return []


class DTensorFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck

        super().__init__(
            supported_ops=[
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
                "torch.mm",
                "torch.addmm",
                "torch.bmm",
                "torch.matmul",
            ],
            check=EagerVsFullGraphDynamicCompileCheck(),
        )

    def supported_dtypes(self):
        """Return list of DTensor-compatible dtypes (no complex types)."""
        return [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.bfloat16,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.bool,
        ]

    def spec_distribution(self):
        """DTensor template: tensor-only (no scalars)."""
        return {
            "tensor_prob": 1.0,
            "scalar_prob": 0.0,
            "allow_tensors": True,
            "allow_scalars": False,
        }

    def imports_codegen(self):
        return [
            "import torch",
            "from torch.distributed.tensor.placement_types import Replicate, Shard",
            "from torch.testing._internal.distributed.fake_pg import FakeStore",
            "from torch.distributed.tensor import DTensor",
        ]

    def flags_codegen(self):
        return [
            "torch._dynamo.config.capture_scalar_outputs = True",
            "torch._dynamo.config.capture_dynamic_output_shape_ops = True",
            "torch._inductor.config.emulate_precision_casts = True",
        ]

    def args_codegen(self, arg_operations):
        """Generate DTensor argument creation code with proper mesh setup."""
        code_lines = []

        # Add DTensor setup code first
        code_lines.extend(
            [
                "world_size = 1024",
                "fake_store = FakeStore()",
                "torch.distributed.init_process_group(",
                '    "fake", store=fake_store, rank=0, world_size=world_size',
                ")",
                "",
                "mesh = torch.distributed.device_mesh.init_device_mesh(",
                '    "cuda",',
                "    (2, 8),",
                "    mesh_dim_names=(",
                '        "dim1", "dim2",',
                "    ),",
                ")",
                "",
                "placements = (Replicate(), Replicate())",
                "",
                "# Sentinel tensor to ensure gradient computation",
                "sentinel_local = torch.tensor(1.0, device='cuda', requires_grad=True)",
                "sentinel = DTensor.from_local(sentinel_local, mesh, placements)",
                "",
            ]
        )

        if arg_operations:
            for i, (node_id, spec) in enumerate(arg_operations):
                arg_name = f"arg_{i}"

                if isinstance(spec, ScalarSpec):
                    # For scalars in DTensor, create a 0-dim tensor
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
                    code_lines.extend(
                        [
                            f"{arg_name}_local = torch.randn((), dtype={dtype_str}, device='cuda', requires_grad=True)",
                            f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                        ]
                    )

                elif isinstance(spec, TensorSpec):
                    size_str = str(spec.size)
                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")

                    # Handle different dtypes appropriately for DTensor
                    if spec.dtype in [
                        torch.int32,
                        torch.int64,
                        torch.int8,
                        torch.int16,
                    ]:
                        # Integer dtypes: use randint and no requires_grad
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randint(1, 10, {size_str}, dtype={dtype_str}, device='cuda')",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )
                    elif spec.dtype == torch.bool:
                        # Boolean dtype: use randint and cast to bool
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randint(0, 2, {size_str}, device='cuda').bool()",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )
                    else:
                        # Float dtypes: use randn and requires_grad
                        code_lines.extend(
                            [
                                f"{arg_name}_local = torch.randn({size_str}, dtype={dtype_str}, device='cuda', requires_grad=True)",
                                f"{arg_name} = DTensor.from_local({arg_name}_local, mesh, placements)",
                            ]
                        )

        return code_lines
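
    # Example of the generated argument lines (illustrative, assuming one
    # float32 TensorSpec of size (4, 4)):
    #   arg_0_local = torch.randn((4, 4), dtype=torch.float32, device='cuda', requires_grad=True)
    #   arg_0 = DTensor.from_local(arg_0_local, mesh, placements)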

    def epilogue_codegen(self):
        return ["torch.distributed.destroy_process_group()"]


class UnbackedFuzzTemplate(FuzzTemplate):
    def __init__(self):
        from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck

        super().__init__(
            supported_ops=[
                # Data-dependent operations that produce unbacked symbolic ints
                "torch.ops.aten.item",
                "torch.ops.aten.nonzero",
                "torch.ops.aten.masked_select",
                "torch.ops.aten.unique",
                # Basic arithmetic operations
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
                # Tensor shape operations
                "torch.Tensor.view",
                "torch.reshape",
                "torch.flatten",
                "torch.squeeze",
                "torch.unsqueeze",
                # Matrix operations
                "torch.mm",
                "torch.addmm",
                "torch.bmm",
                "torch.matmul",
                # Neural network operations
                "torch.nn.functional.embedding",
                "torch.nn.functional.linear",
                # Activation functions
                "torch.nn.functional.relu",
                "torch.nn.functional.leaky_relu",
                "torch.nn.functional.elu",
                "torch.nn.functional.gelu",
                "torch.nn.functional.silu",
                "torch.sigmoid",
                "torch.tanh",
                "torch.nn.functional.softmax",
                # Normalization layers
                "torch.nn.functional.layer_norm",
                "torch.nn.functional.rms_norm",
                "torch.nn.functional.batch_norm",
                "torch.nn.functional.group_norm",
                # Regularization
                "torch.nn.functional.dropout",
            ],
            check=EagerVsFullGraphDynamicCompileCheck(),
        )

    def supported_dtypes(self):
        """Return list of dtypes suited to data-dependent operations."""
        # Focus on dtypes that work well with data-dependent ops and arithmetic.
        # Exclude bool, since arithmetic operations don't work on boolean tensors.
        return [
            torch.float32,
            torch.float64,
            torch.int32,
            torch.int64,
        ]

    def spec_distribution(self):
        """Unbacked template: 50% tensors, 50% scalars."""
        return {
            "tensor_prob": 0.5,
            "scalar_prob": 0.5,
            "allow_tensors": True,
            "allow_scalars": True,
        }

    def imports_codegen(self):
        return [
            "import torch",
        ]

    def flags_codegen(self):
        return [
            "torch._dynamo.config.capture_scalar_outputs = True",
            "torch._dynamo.config.capture_dynamic_output_shape_ops = True",
        ]
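
    # Note: capture_scalar_outputs lets torch.compile trace Tensor.item(),
    # and capture_dynamic_output_shape_ops covers ops such as nonzero(),
    # whose output shape is data-dependent. Both produce unbacked symbolic
    # ints, which is exactly what this template exercises.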

    def epilogue_codegen(self):
        return []


def convert_graph_to_python_code(
    operation_graph: OperationGraph,
    seed: Optional[int] = None,
    template: str = "default",
) -> str:
    """
    Convert an operation graph to executable Python code using topological ordering.

    The graph-based approach generates code by:
    1. Getting the topological order of nodes (dependencies before dependents)
    2. Generating code for each node in that order
    3. Properly handling input dependencies through node connections

    Args:
        operation_graph: OperationGraph instance containing the operation DAG
        seed: Random seed for reproducible code generation. If None, uses the
            current random state.
        template: Fuzz template to use: "default", "dtensor", or "unbacked".

    Returns:
        String containing the complete Python code that executes the operations
    """

    # Instantiate the template
    if template == "dtensor":
        fuzz_template = DTensorFuzzTemplate()
    elif template == "unbacked":
        fuzz_template = UnbackedFuzzTemplate()
    else:
        fuzz_template = DefaultFuzzTemplate()

    # Set seed for reproducible code generation
    if seed is not None:
        import random

        random.seed(seed + 1000)  # Offset to avoid conflicts with graph generation
        torch.manual_seed(seed + 1000)

    if not operation_graph.nodes:
        raise ValueError("Empty operation graph")

    # Get the topological order; this ensures dependencies are processed
    # before their dependents
    topo_order = operation_graph.get_topological_order()

    # Track generated variables and arg operations
    generated_code_lines = []
    node_variables: dict[str, tuple[str, Spec]] = {}  # Maps node_id to (var_name, spec)
    arg_operations: list[tuple[str, Spec]] = []  # (node_id, spec) for arg operations

    # Process nodes in topological order
    for node_id in topo_order:
        node = operation_graph.nodes[node_id]
        op_name = node.op_name
        output_spec = node.output_spec

        # Generate output variable name
        output_var_name = f"var_{node_id}"

        # Generate input variable names from input nodes
        input_var_names = []
        for input_node_id in node.input_nodes:
            if input_node_id in node_variables:
                input_var_name, _ = node_variables[input_node_id]
                input_var_names.append(input_var_name)
            else:
                raise ValueError(
                    f"Node {node_id} depends on {input_node_id}, but {input_node_id} "
                    f"was not processed yet. Topological order may be incorrect."
                )

        # Handle different operation types
        if op_name == "arg" or op_name.startswith("arg_"):
            # Track arg operations for later function signature generation
            arg_operations.append((node_id, output_spec))
            arg_name = f"arg_{len(arg_operations) - 1}"
            # Add a tensor descriptor comment for arg operations too
            descriptor_comment = f"# {format_tensor_descriptor(output_spec)}"
            operation_lines = [f"{output_var_name} = {arg_name} " + descriptor_comment]
        else:
            # Generate operation execution code
            operation_lines = generate_simple_operation_code(
                output_var_name, input_var_names, op_name, output_spec
            )

        # Add proper indentation for the function body
        generated_code_lines.extend(["    " + line for line in operation_lines])

        # Track this node's variable
        node_variables[node_id] = (output_var_name, output_spec)

    # The final result comes from the root node
    root_node_id = operation_graph.root_node_id
    if root_node_id not in node_variables:
        raise ValueError(f"Root node {root_node_id} was not processed")

    final_var_name, _ = node_variables[root_node_id]

    # Generate the function signature based on the discovered arg operations
    if arg_operations:
        arg_names = [f"arg_{i}" for i in range(len(arg_operations))]
        function_signature = f"def fuzzed_program({', '.join(arg_names)}, sentinel)"
    else:
        function_signature = "def fuzzed_program(sentinel)"

    # Build the complete code, with all imports at the top
    code_lines = []
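
    # Overall shape of the assembled program (illustrative):
    #   <imports>
    #   <dynamo/inductor flags>
    #   torch.manual_seed(<seed>)        # if a seed was given
    #   def fuzzed_program(arg_0, ..., sentinel):
    #       var_...                      # one line per graph node
    #       result = <root var> * sentinel
    #       return result
    #   <argument construction>
    #   <check harness invocation>
    #   <template epilogue>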

    # Add template imports
    code_lines.extend(fuzz_template.imports_codegen())

    # Add template flags
    code_lines.extend(fuzz_template.flags_codegen())
    code_lines.append("")

    # Add a single seed at the top if a seed is provided
    if seed is not None:
        code_lines.append(f"torch.manual_seed({seed})")
        code_lines.append("")

    code_lines.append(function_signature + ":")

    # Add the generated operation code
    code_lines.extend(generated_code_lines)

    # Add a return statement with sentinel multiplication to ensure gradient
    # computation; handle complex tensors appropriately based on the template
    if template == "dtensor":
        # For DTensor, avoid the .real operation, which doesn't work with
        # sharding; use abs() instead to get a real result from complex tensors
        code_lines.extend(
            [
                "    # Ensure gradient computation by multiplying with sentinel",
                f"    result = {final_var_name} * sentinel",
                "    if result.is_complex():",
                "        result = result.abs()  # Use abs() instead of .real for DTensor compatibility",
                "    return result",
                "",
            ]
        )
    else:
        code_lines.extend(
            [
                "    # Ensure gradient computation by multiplying with sentinel and taking real part",
                f"    result = {final_var_name} * sentinel",
                "    if result.is_complex():",
                "        result = result.real",
                "    return result",
                "",
            ]
        )

    # Generate argument creation code using the template
    arg_code_lines = fuzz_template.args_codegen(arg_operations)
    code_lines.extend(arg_code_lines)

    # Generate the final execution with both normal and compiled versions
    if arg_operations:
        arg_names = [f"arg_{i}" for i in range(len(arg_operations))]
        if len(arg_names) == 1:
            args_tuple = f"({arg_names[0]},)"  # Single-element tuple needs a trailing comma
        else:
            args_tuple = f"({', '.join(arg_names)})"
    else:
        args_tuple = "()"

    # Generate execution code using the template's check
    check_lines = fuzz_template.check.codegen(f"{args_tuple} + (sentinel,)")
    code_lines.extend([""] + check_lines)

    # Add the template epilogue
    epilogue_lines = fuzz_template.epilogue_codegen()
    if epilogue_lines:
        code_lines.append("")
        code_lines.extend(epilogue_lines)

    return "\n".join(code_lines)


def generate_simple_operation_code(
    output_var: str,
    input_vars: list,
    op_name: str,
    output_spec,
) -> list:
    """
    Generate code lines for executing a single operation using class-based operators.

    Args:
        output_var: Name of the output variable
        input_vars: List of input variable names
        op_name: Name of the operation
        output_spec: Output specification for the operation

    Returns:
        List of code lines implementing the operation.
    """
    # Try to get the operator from the registry
    operator = get_operator(op_name)

    if operator is not None:
        # Use the class-based operator to generate code
        code_line = operator.codegen(output_var, input_vars, output_spec)
        # Add a tensor descriptor comment
        descriptor_comment = f"# {format_tensor_descriptor(output_spec)}"
        return [code_line + " " + descriptor_comment]
    else:
        # Fallback for unknown operations
        return [f"# Unknown operation: {op_name}"]
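

# Illustrative output (a sketch; the exact formatting depends on the
# registered operator and on format_tensor_descriptor): for op_name
# "torch.add" with input_vars ["var_1", "var_2"], the returned line
# might resemble
#   var_0 = torch.add(var_1, var_2) # <tensor descriptor>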


def create_program_file(python_code: str) -> str:
    """
    Create a temporary Python file from the generated code.

    Args:
        python_code: String containing Python code to write

    Returns:
        Path to the created temporary file
    """
    import hashlib

    # Generate a deterministic filename based on code content hash
    code_hash = hashlib.md5(python_code.encode()).hexdigest()[:8]  # noqa: S324
    tmp_dir = "/tmp/torchfuzz"
    os.makedirs(tmp_dir, exist_ok=True)
    generated_file_path = os.path.join(tmp_dir, f"fuzz_{code_hash}.py")

    # Write the generated code to the specified file
    with open(generated_file_path, "w") as f:
        f.write(python_code)

    return generated_file_path
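

# Illustrative end-to-end usage (a sketch; assumes an OperationGraph has
# already been produced elsewhere in torchfuzz):
#   code = convert_graph_to_python_code(graph, seed=0, template="default")
#   path = create_program_file(code)
#   # The file at `path` can then be executed to run the template's check.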