[torchfuzz] Encapsulate fuzzing and codegen logic into ops (#163547)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163547
Approved by: https://github.com/laithsakka
This commit is contained in:
bobrenjc93
2025-09-22 14:12:00 -07:00
committed by PyTorch MergeBot
parent 27164b6788
commit 0b75a16200
16 changed files with 809 additions and 357 deletions

View File

@ -196,6 +196,7 @@ exclude_patterns = [
'tools/test/gen_operators_yaml_test.py',
'tools/test/gen_oplist_test.py',
'tools/test/test_selective_build.py',
'tools/experimental/dynamic_shapes/torchfuzz/**',
]
command = [
'python3',

View File

@ -1,49 +1,19 @@
# mypy: ignore-errors
"""Torchfuzz package for generating and testing random PyTorch operations."""
"""
PyTorch Operation Fuzzer for Dynamic Shapes Testing.
This package provides comprehensive fuzzing tools for testing torch.compile
with dynamic shapes and diverse operation patterns.
"""
from fuzzer import fuzz_and_execute, fuzz_operation_stack
from ops_fuzzer import fuzz_op, fuzz_spec, Operation
from tensor_fuzzer import (
fuzz_scalar,
fuzz_tensor,
fuzz_tensor_simple,
fuzz_tensor_size,
fuzz_torch_tensor_type,
fuzz_valid_stride,
ScalarSpec,
Spec,
TensorSpec,
test_fuzzing_tensors,
)
from visualize_stack import operation_stack_to_dot, visualize_operation_stack
# Make key classes available at package level
from .operators import get_operator, list_operators, register_operator
from .ops_fuzzer import fuzz_operation_graph, fuzz_spec, OperationGraph
from .tensor_fuzzer import ScalarSpec, Spec, TensorSpec
__all__ = [
# Core fuzzing functionality
"fuzz_and_execute",
"fuzz_operation_stack",
"fuzz_op",
"fuzz_spec",
# Tensor fuzzing
"fuzz_tensor",
"fuzz_tensor_simple",
"fuzz_tensor_size",
"fuzz_torch_tensor_type",
"fuzz_valid_stride",
"fuzz_scalar",
"test_fuzzing_tensors",
# Data types and configuration
"Operation",
"TensorSpec",
"ScalarSpec",
"Spec",
# Visualization
"operation_stack_to_dot",
"visualize_operation_stack",
"OperationGraph",
"fuzz_operation_graph",
"fuzz_spec",
"get_operator",
"register_operator",
"list_operators",
]

View File

@ -9,11 +9,12 @@ from queue import Empty, Queue
from threading import Thread
from typing import Any, Optional, Union
from ops_fuzzer import OperationGraph
from tensor_fuzzer import fuzz_scalar, fuzz_tensor_simple, ScalarSpec, Spec, TensorSpec
import torch
from torchfuzz.operators import get_operator
from torchfuzz.ops_fuzzer import OperationGraph
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec
def convert_graph_to_python_code(
operation_graph: OperationGraph, seed: Optional[int] = None
@ -268,7 +269,7 @@ def generate_simple_operation_code(
output_spec,
) -> list:
"""
Generate code lines for executing a single operation (simplified version without arg_tracker).
Generate code lines for executing a single operation using class-based operators.
Args:
output_var: Name of the output variable
@ -276,83 +277,15 @@ def generate_simple_operation_code(
op_name: Name of the operation
output_spec: Output specification for the operation
"""
if op_name == "scalar_add":
return [f"{output_var} = {input_vars[0]} + {input_vars[1]}"]
elif op_name == "scalar_multiply":
return [f"{output_var} = {input_vars[0]} * {input_vars[1]}"]
elif op_name == "torch.ops.aten.item":
return [f"{output_var} = {input_vars[0]}.item()"]
elif op_name == "torch.ops.aten.add":
return [f"{output_var} = torch.ops.aten.add({input_vars[0]}, {input_vars[1]})"]
elif op_name == "torch.ops.aten.mul":
return [f"{output_var} = torch.ops.aten.mul({input_vars[0]}, {input_vars[1]})"]
elif op_name == "constant":
# Create constant by calling fuzzing functions during codegen with deterministic seed
# Use a deterministic seed based on the variable name to ensure reproducibility
var_seed = hash(output_var) % (2**31)
if isinstance(output_spec, ScalarSpec):
# Call fuzz_scalar during codegen and embed the result
actual_value = fuzz_scalar(output_spec, seed=var_seed)
# Format the value for embedding in code
if isinstance(actual_value, bool):
value_str = str(actual_value)
elif isinstance(actual_value, (int, float)):
value_str = repr(actual_value)
elif isinstance(actual_value, complex):
value_str = f"complex({actual_value.real}, {actual_value.imag})"
else:
value_str = repr(actual_value)
return [f"{output_var} = {value_str}"]
elif isinstance(output_spec, TensorSpec):
# Call fuzz_tensor_simple during codegen and embed the result
actual_tensor = fuzz_tensor_simple(
output_spec.size, output_spec.stride, output_spec.dtype, seed=var_seed
)
# Convert tensor to code representation
size_str = str(output_spec.size)
dtype_str = f"torch.{output_spec.dtype}".replace("torch.torch.", "torch.")
# Handle empty tensors (with 0 elements)
if actual_tensor.numel() == 0:
# For empty tensors, use a default fill value based on dtype
default_values = {
torch.float16: 0.0,
torch.float32: 0.0,
torch.float64: 0.0,
torch.bfloat16: 0.0,
torch.int8: 0,
torch.int16: 0,
torch.int32: 0,
torch.int64: 0,
torch.bool: False,
torch.complex64: 0.0,
torch.complex128: 0.0,
}
fill_value = default_values.get(output_spec.dtype, 0)
return [
f"{output_var} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"
]
else:
# For non-empty tensors, use the first element as fill value
fill_value = actual_tensor.flatten()[0].item()
return [
f"{output_var} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"
]
else:
return [f"# Unknown output spec type for constant: {type(output_spec)}"]
# Try to get the operator from the registry
operator = get_operator(op_name)
if operator is not None:
# Use the class-based operator to generate code
code_line = operator.codegen(output_var, input_vars, output_spec)
return [code_line]
else:
# Fallback for unknown operations
return [f"# Unknown operation: {op_name}"]

View File

@ -1,13 +1,21 @@
# mypy: ignore-errors
import logging
import os
import random
import sys
from typing import Any, Optional, Union
from codegen import convert_graph_to_python_code, execute_python_code
from ops_fuzzer import fuzz_operation_graph, fuzz_spec
from visualize_graph import visualize_operation_graph
# Add parent directory to path so we can import torchfuzz as a module
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
import torch
from torchfuzz.codegen import convert_graph_to_python_code, execute_python_code
from torchfuzz.ops_fuzzer import fuzz_operation_graph, fuzz_spec
from torchfuzz.visualize_graph import visualize_operation_graph
def fuzz_and_execute(

View File

@ -0,0 +1,26 @@
"""Torchfuzz operators module."""
from torchfuzz.operators.add import AddOperator
from torchfuzz.operators.arg import ArgOperator
from torchfuzz.operators.base import Operator
from torchfuzz.operators.constant import ConstantOperator
from torchfuzz.operators.item import ItemOperator
from torchfuzz.operators.mul import MulOperator
from torchfuzz.operators.registry import get_operator, list_operators, register_operator
from torchfuzz.operators.scalar_add import ScalarAddOperator
from torchfuzz.operators.scalar_multiply import ScalarMultiplyOperator
__all__ = [
"Operator",
"AddOperator",
"MulOperator",
"ItemOperator",
"ScalarAddOperator",
"ScalarMultiplyOperator",
"ConstantOperator",
"ArgOperator",
"get_operator",
"register_operator",
"list_operators",
]

View File

@ -0,0 +1,71 @@
"""Add operator implementation."""
import random
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
from torchfuzz.type_promotion import (
get_dtype_map,
get_dtype_name,
get_promotion_table_for_strings,
)
class AddOperator(Operator):
    """Tensor addition operator, registered under the name ``torch.ops.aten.add``."""

    def __init__(self):
        """Set the operator name used as its registry key."""
        super().__init__("torch.ops.aten.add")

    def can_produce(self, output_spec: Spec) -> bool:
        """Return True only when the requested output is a ``TensorSpec``."""
        return isinstance(output_spec, TensorSpec)

    def supports_variable_inputs(self) -> bool:
        """Report that the number of operands is not fixed."""
        return True

    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Build the operand specs whose sum matches ``output_spec``.

        Args:
            output_spec: Target spec; must be a ``TensorSpec``.
            num_inputs: Number of operand specs to produce.

        Returns:
            ``num_inputs`` tensor specs sharing the output's size and stride.
            With two or more operands, the first two dtypes come from a
            randomly chosen entry of the shared promotion table and any
            further operands use the output dtype.

        Raises:
            ValueError: If ``output_spec`` is not a ``TensorSpec``.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("AddOperator can only produce TensorSpec outputs")

        table = get_promotion_table_for_strings()
        out_name = get_dtype_name(output_spec.dtype)
        # Dtypes missing from the table fall back to a same-dtype pair.
        candidates = table.get(out_name, [(out_name, out_name)])

        if num_inputs < 2:
            operand_names = [out_name] * num_inputs
        else:
            # Only the leading pair is drawn from the promotion table; the
            # remaining operands (a + b + c + ...) simply reuse the output dtype.
            operand_names = list(random.choice(candidates))
            operand_names.extend([out_name] * (num_inputs - len(operand_names)))

        name_to_dtype = get_dtype_map()
        operand_specs = []
        for name in operand_names:
            operand_specs.append(
                TensorSpec(
                    size=output_spec.size,
                    stride=output_spec.stride,
                    dtype=name_to_dtype.get(name, output_spec.dtype),
                )
            )
        return operand_specs

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Return one line of Python assigning the sum of the inputs to ``output_name``."""
        if len(input_names) != 2:
            # Any other operand count is chained with the ``+`` operator.
            chained = " + ".join(input_names)
            return f"{output_name} = {chained}"
        lhs, rhs = input_names
        return f"{output_name} = torch.ops.aten.add({lhs}, {rhs})"

View File

@ -0,0 +1,31 @@
"""Arg operator implementation."""
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec
class ArgOperator(Operator):
    """Leaf operator standing for a function argument; registered as ``"arg"``."""

    def __init__(self):
        """Set the operator name used as its registry key."""
        super().__init__("arg")

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept every output spec."""
        return True

    def supports_variable_inputs(self) -> bool:
        """Report a fixed input count (an argument takes no inputs)."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 0) -> list[Spec]:
        """Return an empty list: an argument has no input specs."""
        no_inputs: list[Spec] = []
        return no_inputs

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Return a placeholder comment line for ``output_name``.

        Only a comment is emitted here; per the original note, the real
        argument binding is done separately in codegen.py.
        """
        placeholder = f"# {output_name} will be assigned an argument value"
        return placeholder

View File

@ -0,0 +1,39 @@
"""Base operator implementation."""
from abc import ABC, abstractmethod
from torchfuzz.tensor_fuzzer import Spec
class Operator(ABC):
    """Abstract interface every torchfuzz operator implements."""

    def __init__(self, name: str):
        """Store ``name`` on the instance as ``self.name``."""
        self.name = name

    @abstractmethod
    def can_produce(self, output_spec: Spec) -> bool:
        """Tell whether this operator is able to yield ``output_spec``."""

    @abstractmethod
    def supports_variable_inputs(self) -> bool:
        """Tell whether the operator accepts a varying number of inputs."""

    @abstractmethod
    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Return the input specs needed to produce ``output_spec``."""

    @abstractmethod
    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Return the source text that computes ``output_name`` from ``input_names``."""

    def __str__(self) -> str:
        """Render as ``ClassName(operator_name)``."""
        class_name = type(self).__name__
        return f"{class_name}({self.name})"

    def __repr__(self) -> str:
        """Use the same text as ``__str__``."""
        return str(self)

View File

@ -0,0 +1,91 @@
"""Constant operator implementation."""
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import (
fuzz_scalar,
fuzz_tensor_simple,
ScalarSpec,
Spec,
TensorSpec,
)
class ConstantOperator(Operator):
    """Leaf operator that emits a literal scalar or a ``torch.full`` tensor."""

    def __init__(self):
        """Set the operator name used as its registry key."""
        super().__init__("constant")

    def can_produce(self, output_spec: Spec) -> bool:
        """Constant can produce any type of output."""
        return True

    def supports_variable_inputs(self) -> bool:
        """Constant operator does not require inputs."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 0) -> list[Spec]:
        """Constant requires no inputs."""
        return []

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate one line of code that assigns a constant to ``output_name``.

        The constant is produced at codegen time by the fuzzing helpers and
        embedded in the returned source line.

        Args:
            output_name: Variable receiving the constant; also seeds the fuzzer.
            input_names: Unused; a constant has no inputs.
            output_spec: ``ScalarSpec`` or ``TensorSpec`` describing the value.

        Returns:
            An assignment statement, or a ``#`` comment line when
            ``output_spec`` is neither a ``ScalarSpec`` nor a ``TensorSpec``.
        """
        import zlib

        # Derive the seed from the variable name with CRC32 rather than the
        # built-in hash(): str hashes are salted per interpreter process, so
        # hash(output_name) would yield different constants on every run and
        # defeat the reproducibility this seed exists for.
        var_seed = zlib.crc32(output_name.encode("utf-8")) % (2**31)

        if isinstance(output_spec, ScalarSpec):
            # Call fuzz_scalar during codegen and embed the result
            actual_value = fuzz_scalar(output_spec, seed=var_seed)
            if isinstance(actual_value, complex):
                value_str = f"complex({actual_value.real}, {actual_value.imag})"
            else:
                # bool, int, float and anything else all embed via repr()
                # (str(bool) and repr(bool) are the same text).
                value_str = repr(actual_value)
            return f"{output_name} = {value_str}"

        if isinstance(output_spec, TensorSpec):
            import torch

            # Call fuzz_tensor_simple during codegen and embed the result
            actual_tensor = fuzz_tensor_simple(
                output_spec.size, output_spec.stride, output_spec.dtype, seed=var_seed
            )
            size_str = str(output_spec.size)
            # str(dtype) already starts with "torch."; collapse the doubled prefix.
            dtype_str = f"torch.{output_spec.dtype}".replace("torch.torch.", "torch.")

            if actual_tensor.numel() == 0:
                # No element to sample from: fall back to a per-dtype default.
                default_values = {
                    torch.float16: 0.0,
                    torch.float32: 0.0,
                    torch.float64: 0.0,
                    torch.bfloat16: 0.0,
                    torch.int8: 0,
                    torch.int16: 0,
                    torch.int32: 0,
                    torch.int64: 0,
                    torch.bool: False,
                    torch.complex64: 0.0,
                    torch.complex128: 0.0,
                }
                fill_value = default_values.get(output_spec.dtype, 0)
            else:
                # The whole tensor is filled with the first fuzzed element.
                fill_value = actual_tensor.flatten()[0].item()
            return f"{output_name} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"

        return f"# Unknown output spec type for constant: {type(output_spec)}"

View File

@ -0,0 +1,39 @@
"""Item operator implementation."""
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec
class ItemOperator(Operator):
    """Tensor-to-scalar operator, registered under ``torch.ops.aten.item``."""

    def __init__(self):
        """Set the operator name used as its registry key."""
        super().__init__("torch.ops.aten.item")

    def can_produce(self, output_spec: Spec) -> bool:
        """Return True only when the requested output is a ``ScalarSpec``."""
        return isinstance(output_spec, ScalarSpec)

    def supports_variable_inputs(self) -> bool:
        """Report a fixed input count."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
        """Return the single tensor spec that ``.item()`` will be called on.

        The input is a 1-D, one-element tensor (size ``(1,)``, stride ``(1,)``)
        with the same dtype as the requested scalar.

        Raises:
            ValueError: If ``output_spec`` is not a ``ScalarSpec``.
        """
        if not isinstance(output_spec, ScalarSpec):
            raise ValueError("ItemOperator can only produce ScalarSpec outputs")
        return [TensorSpec(size=(1,), stride=(1,), dtype=output_spec.dtype)]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Return ``<output> = <input>.item()``.

        Raises:
            ValueError: If ``input_names`` does not hold exactly one name.
        """
        if len(input_names) == 1:
            source = input_names[0]
            return f"{output_name} = {source}.item()"
        raise ValueError("ItemOperator requires exactly one input")

View File

@ -0,0 +1,72 @@
"""Multiply operator implementation."""
import random
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
class MulOperator(Operator):
    """Tensor multiplication operator, registered under ``torch.ops.aten.mul``."""

    def __init__(self):
        """Set the operator name used as its registry key."""
        super().__init__("torch.ops.aten.mul")

    def can_produce(self, output_spec: Spec) -> bool:
        """Return True only when the requested output is a ``TensorSpec``."""
        return isinstance(output_spec, TensorSpec)

    def supports_variable_inputs(self) -> bool:
        """Report that the number of operands is not fixed."""
        return True

    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Build the operand specs whose product matches ``output_spec``.

        Args:
            output_spec: Target spec; must be a ``TensorSpec``.
            num_inputs: Number of operand specs to produce.

        Returns:
            ``num_inputs`` tensor specs sharing the output's size and stride.
            With two or more operands, the first two dtypes come from a
            randomly chosen entry of the shared promotion table and any
            further operands use the output dtype.

        Raises:
            ValueError: If ``output_spec`` is not a ``TensorSpec``.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("MulOperator can only produce TensorSpec outputs")

        # Imported here, inside the method, exactly as the original does.
        from torchfuzz.type_promotion import (
            get_dtype_map,
            get_dtype_name,
            get_promotion_table_for_strings,
        )

        table = get_promotion_table_for_strings()
        out_name = get_dtype_name(output_spec.dtype)
        # Dtypes missing from the table fall back to a same-dtype pair.
        candidates = table.get(out_name, [(out_name, out_name)])

        if num_inputs < 2:
            operand_names = [out_name] * num_inputs
        else:
            # Only the leading pair is drawn from the promotion table; the
            # remaining operands (a * b * c * ...) simply reuse the output dtype.
            operand_names = list(random.choice(candidates))
            operand_names.extend([out_name] * (num_inputs - len(operand_names)))

        name_to_dtype = get_dtype_map()
        operand_specs = []
        for name in operand_names:
            operand_specs.append(
                TensorSpec(
                    size=output_spec.size,
                    stride=output_spec.stride,
                    dtype=name_to_dtype.get(name, output_spec.dtype),
                )
            )
        return operand_specs

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Return one line of Python assigning the product of the inputs to ``output_name``."""
        if len(input_names) != 2:
            # Any other operand count is chained with the ``*`` operator.
            chained = " * ".join(input_names)
            return f"{output_name} = {chained}"
        lhs, rhs = input_names
        return f"{output_name} = torch.ops.aten.mul({lhs}, {rhs})"

View File

@ -0,0 +1,62 @@
"""Operator registry for mapping operation names to operator instances."""
from typing import Optional
from torchfuzz.operators.add import AddOperator
from torchfuzz.operators.arg import ArgOperator
from torchfuzz.operators.base import Operator
from torchfuzz.operators.constant import ConstantOperator
from torchfuzz.operators.item import ItemOperator
from torchfuzz.operators.mul import MulOperator
from torchfuzz.operators.scalar_add import ScalarAddOperator
from torchfuzz.operators.scalar_multiply import ScalarMultiplyOperator
class OperatorRegistry:
    """Name-to-instance table of the operators known to torchfuzz."""

    def __init__(self):
        """Create the table and fill it with the built-in operators."""
        self._operators: dict[str, Operator] = {}
        self._register_default_operators()

    def _register_default_operators(self):
        """Instantiate and register each built-in operator class, in order."""
        builtin_classes = (
            AddOperator,
            MulOperator,
            ItemOperator,
            ScalarAddOperator,
            ScalarMultiplyOperator,
            ConstantOperator,
            ArgOperator,
        )
        for operator_cls in builtin_classes:
            self.register(operator_cls())

    def register(self, operator: Operator):
        """Store ``operator`` under ``operator.name``, replacing any previous entry."""
        self._operators[operator.name] = operator

    def get(self, op_name: str) -> Optional[Operator]:
        """Return the operator registered as ``op_name``, or ``None`` if absent."""
        return self._operators.get(op_name)

    def list_operators(self) -> dict[str, Operator]:
        """Return a shallow copy of the name-to-operator mapping."""
        return dict(self._operators)


# Module-wide registry used by the helper functions below.
_global_registry = OperatorRegistry()
def get_operator(op_name: str) -> Optional[Operator]:
    """Look up ``op_name`` in the global registry and return the result."""
    registry = _global_registry
    return registry.get(op_name)
def register_operator(operator: Operator):
    """Add ``operator`` to the global registry."""
    registry = _global_registry
    registry.register(operator)
def list_operators() -> dict[str, Operator]:
    """Return the global registry's name-to-operator mapping."""
    registry = _global_registry
    return registry.list_operators()

View File

@ -0,0 +1,43 @@
"""Scalar add operator implementation."""
import random
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec
class ScalarAddOperator(Operator):
    """Two-operand scalar addition, registered as ``"scalar_add"``."""

    def __init__(self):
        """Set the operator name used as its registry key."""
        super().__init__("scalar_add")

    def can_produce(self, output_spec: Spec) -> bool:
        """Return True only when the requested output is a ``ScalarSpec``."""
        return isinstance(output_spec, ScalarSpec)

    def supports_variable_inputs(self) -> bool:
        """Report a fixed input count."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Return two scalar specs whose dtypes are a randomly chosen promotion pair.

        Raises:
            ValueError: If ``output_spec`` is not a ``ScalarSpec``.
        """
        if not isinstance(output_spec, ScalarSpec):
            raise ValueError("ScalarAddOperator can only produce ScalarSpec outputs")

        # Imported here, inside the method, exactly as the original does.
        from torchfuzz.type_promotion import get_scalar_promotion_pairs

        pair = random.choice(get_scalar_promotion_pairs(output_spec.dtype))
        lhs_spec = ScalarSpec(dtype=pair[0])
        rhs_spec = ScalarSpec(dtype=pair[1])
        return [lhs_spec, rhs_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Return ``<output> = <a> + <b>``.

        Raises:
            ValueError: If ``input_names`` does not hold exactly two names.
        """
        if len(input_names) == 2:
            lhs, rhs = input_names
            return f"{output_name} = {lhs} + {rhs}"
        raise ValueError("ScalarAddOperator requires exactly two inputs")

View File

@ -0,0 +1,45 @@
"""Scalar multiply operator implementation."""
import random
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec
class ScalarMultiplyOperator(Operator):
    """Two-operand scalar multiplication, registered as ``"scalar_multiply"``."""

    def __init__(self):
        """Set the operator name used as its registry key."""
        super().__init__("scalar_multiply")

    def can_produce(self, output_spec: Spec) -> bool:
        """Return True only when the requested output is a ``ScalarSpec``."""
        return isinstance(output_spec, ScalarSpec)

    def supports_variable_inputs(self) -> bool:
        """Report a fixed input count."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Return two scalar specs whose dtypes are a randomly chosen promotion pair.

        Raises:
            ValueError: If ``output_spec`` is not a ``ScalarSpec``.
        """
        if not isinstance(output_spec, ScalarSpec):
            raise ValueError(
                "ScalarMultiplyOperator can only produce ScalarSpec outputs"
            )

        # Imported here, inside the method, exactly as the original does.
        from torchfuzz.type_promotion import get_scalar_promotion_pairs

        pair = random.choice(get_scalar_promotion_pairs(output_spec.dtype))
        lhs_spec = ScalarSpec(dtype=pair[0])
        rhs_spec = ScalarSpec(dtype=pair[1])
        return [lhs_spec, rhs_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Return ``<output> = <a> * <b>``.

        Raises:
            ValueError: If ``input_names`` does not hold exactly two names.
        """
        if len(input_names) == 2:
            lhs, rhs = input_names
            return f"{output_name} = {lhs} * {rhs}"
        raise ValueError("ScalarMultiplyOperator requires exactly two inputs")

View File

@ -4,7 +4,10 @@ import random
from dataclasses import dataclass
from typing import Optional
from tensor_fuzzer import (
import torch
from torchfuzz.operators import get_operator, list_operators
from torchfuzz.tensor_fuzzer import (
fuzz_tensor_size,
fuzz_torch_tensor_type,
fuzz_valid_stride,
@ -14,8 +17,6 @@ from tensor_fuzzer import (
TensorSpec,
)
import torch
@dataclass
class OperationNode:
@ -172,11 +173,7 @@ def fuzz_spec() -> Spec:
def fuzz_op(target_spec: Spec, depth, stack_size) -> tuple[str, list[Spec]]:
"""
Given an output specification, returns an operation that can
produce a tensor with that layout.
Supports:
- For scalars: scalar_add, scalar_multiply, item, constant, arg
- For tensors: aten.add, aten.mul, constant, arg
produce a tensor with that layout using the operator class system.
Args:
target_spec: Desired output specification (TensorSpec or ScalarSpec)
@ -188,240 +185,70 @@ def fuzz_op(target_spec: Spec, depth, stack_size) -> tuple[str, list[Spec]]:
Tuple of (operation_name, list_of_argument_specs) where each argument spec
describes the layout requirements for the operation's inputs
"""
if isinstance(target_spec, ScalarSpec):
if target_spec.constant is not None:
# At depth 0, only allow constant operation
return _get_constant_args_specs(target_spec)
# Get all available operators
available_operators = list_operators()
# Filter operators that can produce the target spec
compatible_ops = []
for op_name, operator in available_operators.items():
if operator.can_produce(target_spec):
compatible_ops.append((op_name, operator))
if not compatible_ops:
raise ValueError(f"No operators available that can produce {target_spec}")
# Categorize operators into leaf and non-leaf
leaf_ops = []
non_leaf_ops = []
for op_name, operator in compatible_ops:
if op_name in ["constant", "arg"] or op_name.startswith("arg_"):
leaf_ops.append((op_name, operator))
else:
non_leaf_ops.append((op_name, operator))
# Choose operation based on depth and stack size constraints
if depth == 0:
# At depth 0, only allow leaf operations
ops = ["constant", "arg"]
chosen_op = random.choice(ops)
if not leaf_ops:
# If no leaf ops can produce this spec, fallback to arg
return _get_arg_args_specs(target_spec)
chosen_op_name, chosen_operator = random.choice(leaf_ops)
else:
# At higher depths, allow all scalar operations
non_leaf_ops = ["scalar_add", "scalar_multiply", "torch.ops.aten.item"]
leaf_ops = ["constant", "arg"]
# At higher depths, choose between leaf and non-leaf operations
# Reduce probability of leaf operations when stack_size < 10
if stack_size < 10 or depth > 7:
if (stack_size < 10 or depth > 7) and non_leaf_ops:
# 80% chance of non-leaf, 20% chance of leaf
if random.random() < 0.8:
chosen_op = random.choice(non_leaf_ops)
chosen_op_name, chosen_operator = random.choice(non_leaf_ops)
else:
chosen_op = random.choice(leaf_ops)
chosen_op_name, chosen_operator = (
random.choice(leaf_ops) if leaf_ops else random.choice(non_leaf_ops)
)
else:
# Normal probability distribution
all_ops = non_leaf_ops + leaf_ops
chosen_op = random.choice(all_ops)
chosen_op_name, chosen_operator = (
random.choice(all_ops) if all_ops else ("arg", get_operator("arg"))
)
if chosen_op == "scalar_add":
return _get_scalar_add_args_specs(target_spec)
elif chosen_op == "scalar_multiply":
return _get_scalar_multiply_args_specs(target_spec)
elif chosen_op == "torch.ops.aten.item":
return _get_item_args_specs(target_spec)
elif chosen_op == "constant":
return _get_constant_args_specs(target_spec)
else: # arg
# Use the operator to decompose the target spec into input specs
try:
if chosen_op_name.startswith("arg_"):
# Handle special arg_ operations
return chosen_op_name, []
elif chosen_op_name in ["constant", "arg"]:
# Handle leaf operations
return chosen_op_name, []
else:
# Use the operator's decompose method
input_specs = chosen_operator.decompose(target_spec)
return chosen_op_name, input_specs
except Exception as e:
# Fallback to arg if decomposition fails
print(f"Warning: operator {chosen_op_name} decomposition failed: {e}")
return _get_arg_args_specs(target_spec)
elif isinstance(target_spec, TensorSpec):
if depth == 0:
# At depth 0, only allow leaf operations
ops = ["arg"]
chosen_op = random.choice(ops)
else:
# At higher depths, allow all tensor operations
non_leaf_ops = [
"torch.ops.aten.add",
"torch.ops.aten.mul",
]
leaf_ops = ["arg"]
# Reduce probability of leaf operations when stack_size < 10
if stack_size < 10:
# 80% chance of non-leaf, 20% chance of leaf
if random.random() < 0.8:
chosen_op = random.choice(non_leaf_ops)
else:
chosen_op = random.choice(leaf_ops)
else:
# Normal probability distribution
all_ops = non_leaf_ops + leaf_ops
chosen_op = random.choice(all_ops)
if chosen_op == "torch.ops.aten.add":
return _get_aten_add_args_specs(target_spec)
elif chosen_op == "torch.ops.aten.mul":
return _get_aten_mul_args_specs(target_spec)
elif chosen_op == "constant":
return _get_constant_args_specs(target_spec)
else: # arg
return _get_arg_args_specs(target_spec)
else:
raise ValueError(f"Unknown target spec type: {type(target_spec)}")
def _get_scalar_add_args_specs(target_spec: ScalarSpec) -> tuple[str, list[Spec]]:
    """Return ``("scalar_add", specs)`` with specs chosen by the promotion helper."""
    # Operand dtypes are picked so that they promote to the target dtype.
    return "scalar_add", _get_promoted_scalar_args(target_spec.dtype)
def _get_scalar_multiply_args_specs(target_spec: ScalarSpec) -> tuple[str, list[Spec]]:
    """Return ``("scalar_multiply", specs)`` with specs chosen by the promotion helper."""
    # Operand dtypes are picked so that they promote to the target dtype.
    return "scalar_multiply", _get_promoted_scalar_args(target_spec.dtype)
# Promotion lookup: maps each target dtype to the input dtypes listed for it.
# The table is derived from this simplified PyTorch ordering:
#   bool < int8 < int16 < int32 < int64 < float16 < float32 < float64
#        < complex64 < complex128
# Unsigned integer types are left out (limited promotion support).
_PROMOTION_LADDER = (
    torch.bool,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.float16,
    torch.float32,
    torch.float64,
    torch.complex64,
    torch.complex128,
)

# Each dtype lists itself (last) preceded by everything below it on the ladder.
_PROMOTION_CHAINS = {
    dtype: list(_PROMOTION_LADDER[: rank + 1])
    for rank, dtype in enumerate(_PROMOTION_LADDER)
}
# The single exception in the table: complex64 does not list float64.
_PROMOTION_CHAINS[torch.complex64].remove(torch.float64)
def _get_promoted_dtypes(target_dtype: torch.dtype) -> list[torch.dtype]:
    """Pick two input dtypes for an operation whose result dtype is ``target_dtype``.

    A strategy is chosen at random: either both inputs use ``target_dtype``,
    or one input uses ``target_dtype`` and the other a randomly chosen entry
    listed before it in ``_PROMOTION_CHAINS`` (in random order). Dtypes with
    no earlier entries always get the same-dtype pair.
    """
    chain = _PROMOTION_CHAINS.get(target_dtype, [target_dtype])
    strategy = random.choice(["same_type", "mixed_promotion"])

    # Everything except the final chain entry (the target itself).
    lower_candidates = chain[:-1]
    if strategy == "same_type" or not lower_candidates:
        return [target_dtype, target_dtype]

    lower = random.choice(lower_candidates)
    # Coin flip decides which side receives the lower dtype.
    target_first = random.random() < 0.5
    return [target_dtype, lower] if target_first else [lower, target_dtype]
def _get_promoted_scalar_args(target_dtype: torch.dtype) -> list[Spec]:
    """Return two ``ScalarSpec`` inputs whose dtypes come from ``_get_promoted_dtypes``."""
    dtype_pair = _get_promoted_dtypes(target_dtype)
    # Both inputs stay ScalarSpec: per the original note, mixing in a 0-D
    # TensorSpec would make the output a 0-D tensor rather than a scalar.
    first = ScalarSpec(dtype_pair[0])
    second = ScalarSpec(dtype_pair[1])
    return [first, second]
def _get_item_args_specs(target_spec: ScalarSpec) -> tuple[str, list[Spec]]:
    """Return ``("torch.ops.aten.item", [spec])`` for a scalar target.

    The single input is a 1-D, one-element tensor (size ``(1,)``, stride
    ``(1,)``) with the target's dtype, so ``.item()`` can extract the scalar.
    """
    arg_specs: list[Spec] = [
        TensorSpec(size=(1,), stride=(1,), dtype=target_spec.dtype)
    ]
    return "torch.ops.aten.item", arg_specs
def _get_aten_add_args_specs(target_spec: TensorSpec) -> tuple[str, list[Spec]]:
    """Return ``("torch.ops.aten.add", specs)`` for a tensor target.

    Both inputs reuse the target's size and stride; their dtypes are the
    pair returned by ``_get_promoted_dtypes``.
    """
    dtype_pair = _get_promoted_dtypes(target_spec.dtype)
    lhs = TensorSpec(target_spec.size, target_spec.stride, dtype_pair[0])
    rhs = TensorSpec(target_spec.size, target_spec.stride, dtype_pair[1])
    arg_specs: list[Spec] = [lhs, rhs]
    return "torch.ops.aten.add", arg_specs
def _get_aten_mul_args_specs(target_spec: TensorSpec) -> tuple[str, list[Spec]]:
    """Return ``("torch.ops.aten.mul", specs)`` for a tensor target.

    Both inputs reuse the target's size and stride; their dtypes are the
    pair returned by ``_get_promoted_dtypes``.
    """
    dtype_pair = _get_promoted_dtypes(target_spec.dtype)
    lhs = TensorSpec(target_spec.size, target_spec.stride, dtype_pair[0])
    rhs = TensorSpec(target_spec.size, target_spec.stride, dtype_pair[1])
    arg_specs: list[Spec] = [lhs, rhs]
    return "torch.ops.aten.mul", arg_specs
def _get_constant_args_specs(target_spec: Spec) -> tuple[str, list[Spec]]:
    """
    Describe the inputs of the ``constant`` operation.

    A constant produces a fixed value/tensor on its own, so ``target_spec``
    is not consulted and the argument list is always empty.

    Returns the operation name together with an empty list of argument specs.
    """
    no_inputs: list[Spec] = []
    return "constant", no_inputs
# Module-level counter backing the generation of unique argument IDs; starts at 0.
# NOTE(review): the code that reads/advances it is outside this excerpt — confirm
# how it is incremented (and whether it is ever reset) before relying on it.
_next_arg_id = 0

View File

@ -0,0 +1,194 @@
"""Type promotion utilities for torchfuzz operators."""
import random
import torch
# Promotion chains: for every supported output dtype, the input dtypes that may
# be combined with it while the result still promotes to that output dtype.
# - The output dtype itself is always the LAST entry of its chain;
#   get_promoted_dtypes relies on this when it strips it off with `[:-1]`.
# - bool and the signed integers form a simple ladder; every floating/complex
#   chain starts with that whole ladder.
# - float64 is not listed under complex64: in PyTorch, float64 combined with
#   complex64 promotes to complex128, not complex64.
# - Unsigned integer types and bfloat16 have no entry here.
_BOOL_AND_INT_DTYPES = [
    torch.bool,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
]

PROMOTION_CHAINS = {
    # Slices and unpacking give each chain its own independent list object.
    torch.bool: _BOOL_AND_INT_DTYPES[:1],
    torch.int8: _BOOL_AND_INT_DTYPES[:2],
    torch.int16: _BOOL_AND_INT_DTYPES[:3],
    torch.int32: _BOOL_AND_INT_DTYPES[:4],
    torch.int64: _BOOL_AND_INT_DTYPES[:5],
    torch.float16: [*_BOOL_AND_INT_DTYPES, torch.float16],
    torch.float32: [*_BOOL_AND_INT_DTYPES, torch.float16, torch.float32],
    torch.float64: [
        *_BOOL_AND_INT_DTYPES,
        torch.float16,
        torch.float32,
        torch.float64,
    ],
    torch.complex64: [
        *_BOOL_AND_INT_DTYPES,
        torch.float16,
        torch.float32,
        torch.complex64,
    ],
    torch.complex128: [
        *_BOOL_AND_INT_DTYPES,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.complex64,
        torch.complex128,
    ],
}
def get_promoted_dtypes(target_dtype: torch.dtype) -> list[torch.dtype]:
    """
    Pick two input dtypes whose promoted result is ``target_dtype``.

    With equal probability the pair is either ``[target, target]`` or a mixed
    pair made of ``target_dtype`` and one randomly chosen narrower dtype from
    its entry in ``PROMOTION_CHAINS`` (in random order). Dtypes without an
    entry, or without any narrower dtype, always yield ``[target, target]``.
    """
    chain = PROMOTION_CHAINS.get(target_dtype, [target_dtype])

    # First draw: same-type pair vs. mixed-type pair.
    if random.choice(["same_type", "mixed_promotion"]) == "same_type":
        return [target_dtype, target_dtype]

    # Everything before the final chain entry (the target itself) is narrower.
    narrower_dtypes = chain[:-1]
    if not narrower_dtypes:
        # Nothing narrower to mix with, so degrade to the same-type pair.
        return [target_dtype, target_dtype]

    partner = random.choice(narrower_dtypes)
    # Coin flip decides which operand position receives the narrower dtype.
    if random.random() < 0.5:
        return [target_dtype, partner]
    return [partner, target_dtype]
def get_dtype_name(dtype: torch.dtype) -> str:
    """Return the short name of ``dtype``: the text after the last ``.`` of ``str(dtype)``."""
    _prefix, _dot, short_name = str(dtype).rpartition(".")
    return short_name
def get_promotion_table_for_strings() -> dict[str, list[tuple[str, str]]]:
    """
    Get promotion table using string dtype names for backward compatibility.

    Returns:
        A new dictionary mapping an output dtype name to the list of
        ``(lhs, rhs)`` input dtype name pairs that promote to that output
        under PyTorch's type promotion rules. Every name used in a pair is
        itself a key of the table.
    """
    return {
        "float32": [
            ("float32", "float32"),
            ("bfloat16", "float32"),
            ("float32", "bfloat16"),
            ("float16", "float32"),
            ("float32", "float16"),
        ],
        # float32 mixed with bfloat16 promotes to float32, so it cannot appear
        # here; an integer partner keeps the result at bfloat16.
        "bfloat16": [
            ("bfloat16", "bfloat16"),
            ("int32", "bfloat16"),
            ("bfloat16", "int32"),
        ],
        # Same reasoning as bfloat16: float32 would widen the result to float32.
        "float16": [
            ("float16", "float16"),
            ("int32", "float16"),
            ("float16", "int32"),
        ],
        # int64 mixed with int32 promotes to int64, and no narrower dtype is a
        # key of this table, so only the same-type pair is valid.
        "int32": [
            ("int32", "int32"),
        ],
        "int64": [
            ("int64", "int64"),
            ("int32", "int64"),
            ("int64", "int32"),
        ],
    }
def get_dtype_map() -> dict:
    """Return a fresh dict mapping dtype names (e.g. ``"float32"``) to torch dtypes."""
    # Each name is also the attribute name of the dtype on the torch module.
    dtype_names = (
        "float32",
        "float16",
        "bfloat16",
        "int32",
        "int64",
        "bool",
        "int8",
        "int16",
        "float64",
        "complex64",
        "complex128",
    )
    return {name: getattr(torch, name) for name in dtype_names}
def get_scalar_promotion_pairs(
    target_dtype: torch.dtype,
) -> list[tuple[torch.dtype, torch.dtype]]:
    """
    Get promotion pairs for scalar operations.

    Args:
        target_dtype: The dtype the combined operands must promote to.

    Returns:
        A new list of ``(dtype1, dtype2)`` tuples that promote to
        ``target_dtype``. The same-type pair always comes first. Dtypes
        without a dedicated entry get only ``(target_dtype, target_dtype)``.
    """
    pairs_by_target = {
        torch.float32: [
            (torch.float32, torch.float32),
            (torch.float16, torch.float32),
            (torch.float32, torch.float16),
            (torch.int32, torch.float32),
            (torch.float32, torch.int32),
        ],
        torch.float64: [
            (torch.float64, torch.float64),
            (torch.float32, torch.float64),
            (torch.float64, torch.float32),
        ],
        # The mixed partner must be NARROWER than int32: int64 combined with
        # int32 promotes to int64, which would not match the requested target.
        torch.int32: [
            (torch.int32, torch.int32),
            (torch.int16, torch.int32),
            (torch.int32, torch.int16),
        ],
        torch.int64: [
            (torch.int64, torch.int64),
            (torch.int32, torch.int64),
            (torch.int64, torch.int32),
        ],
    }
    # The table is rebuilt on every call, so callers may mutate the result.
    return pairs_by_target.get(target_dtype, [(target_dtype, target_dtype)])