[torchfuzz] Encapsulate fuzzing and codegen logic into ops (#163547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163547
Approved by: https://github.com/laithsakka
Committed by: PyTorch MergeBot
Parent: 27164b6788
Commit: 0b75a16200
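
Summary of the change: the string-dispatch fuzzing and codegen logic is factored into per-operation classes (can_produce / decompose / codegen) that are looked up by name in a registry. A minimal sketch of the new flow, assuming the torchfuzz package is importable (everything named here is introduced by the diff below):

import torch

from torchfuzz.operators import get_operator
from torchfuzz.tensor_fuzzer import TensorSpec

spec = TensorSpec(size=(2, 2), stride=(2, 1), dtype=torch.float32)
op = get_operator("torch.ops.aten.add")  # registry lookup by op name
input_specs = op.decompose(spec)         # specs the inputs must satisfy
print(op.codegen("var_0", ["var_1", "var_2"], spec))
# var_0 = torch.ops.aten.add(var_1, var_2)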
@@ -196,6 +196,7 @@ exclude_patterns = [
     'tools/test/gen_operators_yaml_test.py',
     'tools/test/gen_oplist_test.py',
     'tools/test/test_selective_build.py',
+    'tools/experimental/dynamic_shapes/torchfuzz/**',
 ]
 command = [
     'python3',
tools/experimental/dynamic_shapes/torchfuzz/__init__.py
@@ -1,49 +1,19 @@
 # mypy: ignore-errors
-"""
-PyTorch Operation Fuzzer for Dynamic Shapes Testing.
-
-This package provides comprehensive fuzzing tools for testing torch.compile
-with dynamic shapes and diverse operation patterns.
-"""
-
-from fuzzer import fuzz_and_execute, fuzz_operation_stack
-from ops_fuzzer import fuzz_op, fuzz_spec, Operation
-from tensor_fuzzer import (
-    fuzz_scalar,
-    fuzz_tensor,
-    fuzz_tensor_simple,
-    fuzz_tensor_size,
-    fuzz_torch_tensor_type,
-    fuzz_valid_stride,
-    ScalarSpec,
-    Spec,
-    TensorSpec,
-    test_fuzzing_tensors,
-)
-from visualize_stack import operation_stack_to_dot, visualize_operation_stack
-
+"""Torchfuzz package for generating and testing random PyTorch operations."""
+
+# Make key classes available at package level
+from .operators import get_operator, list_operators, register_operator
+from .ops_fuzzer import fuzz_operation_graph, fuzz_spec, OperationGraph
+from .tensor_fuzzer import ScalarSpec, Spec, TensorSpec
+
 
 __all__ = [
-    # Core fuzzing functionality
-    "fuzz_and_execute",
-    "fuzz_operation_stack",
-    "fuzz_op",
-    "fuzz_spec",
-    # Tensor fuzzing
-    "fuzz_tensor",
-    "fuzz_tensor_simple",
-    "fuzz_tensor_size",
-    "fuzz_torch_tensor_type",
-    "fuzz_valid_stride",
-    "fuzz_scalar",
-    "test_fuzzing_tensors",
-    # Data types and configuration
-    "Operation",
     "TensorSpec",
     "ScalarSpec",
     "Spec",
-    # Visualization
-    "operation_stack_to_dot",
-    "visualize_operation_stack",
+    "OperationGraph",
+    "fuzz_operation_graph",
+    "fuzz_spec",
+    "get_operator",
+    "register_operator",
+    "list_operators",
 ]
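
The package surface shrinks accordingly. A quick check of the new public API (a sketch, assuming the package is importable; fuzz_spec's zero-argument signature is confirmed by the ops_fuzzer hunk below):

from torchfuzz import OperationGraph, ScalarSpec, TensorSpec, fuzz_operation_graph, fuzz_spec

spec = fuzz_spec()  # random TensorSpec or ScalarSpec target
print(type(spec).__name__)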
tools/experimental/dynamic_shapes/torchfuzz/codegen.py
@@ -9,11 +9,12 @@ from queue import Empty, Queue
 from threading import Thread
 from typing import Any, Optional, Union
 
-from ops_fuzzer import OperationGraph
-from tensor_fuzzer import fuzz_scalar, fuzz_tensor_simple, ScalarSpec, Spec, TensorSpec
-
 import torch
+
+from torchfuzz.operators import get_operator
+from torchfuzz.ops_fuzzer import OperationGraph
+from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec
 
 
 def convert_graph_to_python_code(
     operation_graph: OperationGraph, seed: Optional[int] = None
@@ -268,7 +269,7 @@ def generate_simple_operation_code(
     output_spec,
 ) -> list:
     """
-    Generate code lines for executing a single operation (simplified version without arg_tracker).
+    Generate code lines for executing a single operation using class-based operators.
 
     Args:
         output_var: Name of the output variable
@@ -276,83 +277,15 @@ def generate_simple_operation_code(
         op_name: Name of the operation
         output_spec: Output specification for the operation
     """
-    if op_name == "scalar_add":
-        return [f"{output_var} = {input_vars[0]} + {input_vars[1]}"]
-
-    elif op_name == "scalar_multiply":
-        return [f"{output_var} = {input_vars[0]} * {input_vars[1]}"]
-
-    elif op_name == "torch.ops.aten.item":
-        return [f"{output_var} = {input_vars[0]}.item()"]
-
-    elif op_name == "torch.ops.aten.add":
-        return [f"{output_var} = torch.ops.aten.add({input_vars[0]}, {input_vars[1]})"]
-
-    elif op_name == "torch.ops.aten.mul":
-        return [f"{output_var} = torch.ops.aten.mul({input_vars[0]}, {input_vars[1]})"]
-
-    elif op_name == "constant":
-        # Create constant by calling fuzzing functions during codegen with deterministic seed
-        # Use a deterministic seed based on the variable name to ensure reproducibility
-        var_seed = hash(output_var) % (2**31)
-
-        if isinstance(output_spec, ScalarSpec):
-            # Call fuzz_scalar during codegen and embed the result
-            actual_value = fuzz_scalar(output_spec, seed=var_seed)
-
-            # Format the value for embedding in code
-            if isinstance(actual_value, bool):
-                value_str = str(actual_value)
-            elif isinstance(actual_value, (int, float)):
-                value_str = repr(actual_value)
-            elif isinstance(actual_value, complex):
-                value_str = f"complex({actual_value.real}, {actual_value.imag})"
-            else:
-                value_str = repr(actual_value)
-
-            return [f"{output_var} = {value_str}"]
-
-        elif isinstance(output_spec, TensorSpec):
-            # Call fuzz_tensor_simple during codegen and embed the result
-            actual_tensor = fuzz_tensor_simple(
-                output_spec.size, output_spec.stride, output_spec.dtype, seed=var_seed
-            )
-
-            # Convert tensor to code representation
-            size_str = str(output_spec.size)
-            dtype_str = f"torch.{output_spec.dtype}".replace("torch.torch.", "torch.")
-
-            # Handle empty tensors (with 0 elements)
-            if actual_tensor.numel() == 0:
-                # For empty tensors, use a default fill value based on dtype
-                default_values = {
-                    torch.float16: 0.0,
-                    torch.float32: 0.0,
-                    torch.float64: 0.0,
-                    torch.bfloat16: 0.0,
-                    torch.int8: 0,
-                    torch.int16: 0,
-                    torch.int32: 0,
-                    torch.int64: 0,
-                    torch.bool: False,
-                    torch.complex64: 0.0,
-                    torch.complex128: 0.0,
-                }
-                fill_value = default_values.get(output_spec.dtype, 0)
-                return [
-                    f"{output_var} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"
-                ]
-            else:
-                # For non-empty tensors, use the first element as fill value
-                fill_value = actual_tensor.flatten()[0].item()
-                return [
-                    f"{output_var} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"
-                ]
-
-        else:
-            return [f"# Unknown output spec type for constant: {type(output_spec)}"]
-
+    # Try to get the operator from the registry
+    operator = get_operator(op_name)
+
+    if operator is not None:
+        # Use the class-based operator to generate code
+        code_line = operator.codegen(output_var, input_vars, output_spec)
+        return [code_line]
     else:
+        # Fallback for unknown operations
         return [f"# Unknown operation: {op_name}"]
 
 
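The dispatch above replaces the per-op string matching with a single registry lookup; unknown names fall through to the comment line. A small sketch (assuming the package is importable):

import torch

from torchfuzz.operators import get_operator
from torchfuzz.tensor_fuzzer import ScalarSpec

op = get_operator("scalar_add")
print(op.codegen("var_0", ["var_1", "var_2"], ScalarSpec(dtype=torch.float32)))
# var_0 = var_1 + var_2

print(get_operator("not_a_real_op"))  # None -> "# Unknown operation: ..." fallback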
tools/experimental/dynamic_shapes/torchfuzz/fuzzer.py
@@ -1,13 +1,21 @@
 # mypy: ignore-errors
 import logging
+import os
 import random
+import sys
 from typing import Any, Optional, Union
 
-from codegen import convert_graph_to_python_code, execute_python_code
-from ops_fuzzer import fuzz_operation_graph, fuzz_spec
-from visualize_graph import visualize_operation_graph
-
+# Add parent directory to path so we can import torchfuzz as a module
+current_dir = os.path.dirname(os.path.abspath(__file__))
+parent_dir = os.path.dirname(current_dir)
+if parent_dir not in sys.path:
+    sys.path.insert(0, parent_dir)
+
 import torch
 
+from torchfuzz.codegen import convert_graph_to_python_code, execute_python_code
+from torchfuzz.ops_fuzzer import fuzz_operation_graph, fuzz_spec
+from torchfuzz.visualize_graph import visualize_operation_graph
+
 
 def fuzz_and_execute(
tools/experimental/dynamic_shapes/torchfuzz/operators/__init__.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""Torchfuzz operators module."""

from torchfuzz.operators.add import AddOperator
from torchfuzz.operators.arg import ArgOperator
from torchfuzz.operators.base import Operator
from torchfuzz.operators.constant import ConstantOperator
from torchfuzz.operators.item import ItemOperator
from torchfuzz.operators.mul import MulOperator
from torchfuzz.operators.registry import get_operator, list_operators, register_operator
from torchfuzz.operators.scalar_add import ScalarAddOperator
from torchfuzz.operators.scalar_multiply import ScalarMultiplyOperator


__all__ = [
    "Operator",
    "AddOperator",
    "MulOperator",
    "ItemOperator",
    "ScalarAddOperator",
    "ScalarMultiplyOperator",
    "ConstantOperator",
    "ArgOperator",
    "get_operator",
    "register_operator",
    "list_operators",
]
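
Importing the subpackage pulls in the default registry, so listing it shows the seven operators registered by this commit (a sketch, assuming the package is importable):

from torchfuzz.operators import get_operator, list_operators

print(sorted(list_operators()))
# ['arg', 'constant', 'scalar_add', 'scalar_multiply',
#  'torch.ops.aten.add', 'torch.ops.aten.item', 'torch.ops.aten.mul']
print(get_operator("constant"))  # ConstantOperator(constant)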
tools/experimental/dynamic_shapes/torchfuzz/operators/add.py (new file, 71 lines)
@@ -0,0 +1,71 @@
"""Add operator implementation."""

import random

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
from torchfuzz.type_promotion import (
    get_dtype_map,
    get_dtype_name,
    get_promotion_table_for_strings,
)


class AddOperator(Operator):
    """Operator for element-wise addition."""

    def __init__(self):
        super().__init__("torch.ops.aten.add")

    def can_produce(self, output_spec: Spec) -> bool:
        """Add can produce tensors but not scalars."""
        return isinstance(output_spec, TensorSpec)

    def supports_variable_inputs(self) -> bool:
        """Add operator supports variable number of inputs."""
        return True

    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Decompose tensor into input tensors for addition with type promotion."""
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("AddOperator can only produce TensorSpec outputs")

        # Use shared type promotion table
        promotion_table = get_promotion_table_for_strings()

        # If num_inputs > 2, promote left-to-right (e.g. (((a + b) + c) + d))
        # For simplicity, we generate the first two with promotion, rest match output dtype
        dtype_str = get_dtype_name(output_spec.dtype)
        supported_types = promotion_table.get(dtype_str, [(dtype_str, dtype_str)])

        # Pick a random promotion pattern for the first two inputs
        if num_inputs >= 2:
            dtypes = list(random.choice(supported_types))
            # For >2 inputs, fill with output dtype
            while len(dtypes) < num_inputs:
                dtypes.append(dtype_str)
        else:
            dtypes = [dtype_str] * num_inputs

        # Convert dtype strings back to torch dtypes
        dtype_map = get_dtype_map()

        return [
            TensorSpec(
                size=output_spec.size,
                stride=output_spec.stride,
                dtype=dtype_map.get(dt, output_spec.dtype),
            )
            for dt in dtypes
        ]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for addition operation."""
        if len(input_names) == 2:
            return f"{output_name} = torch.ops.aten.add({input_names[0]}, {input_names[1]})"
        else:
            # Sum all input tensors
            expr = " + ".join(input_names)
            return f"{output_name} = {expr}"
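
A sketch of AddOperator.decompose in use: for a float32 output with three inputs, the first two dtypes come from the promotion table and the rest fall back to the output dtype (assumes the package is importable; the printed dtypes vary with the random choice):

import torch

from torchfuzz.operators.add import AddOperator
from torchfuzz.tensor_fuzzer import TensorSpec

out = TensorSpec(size=(4,), stride=(1,), dtype=torch.float32)
for s in AddOperator().decompose(out, num_inputs=3):
    print(s.dtype)  # e.g. torch.bfloat16, torch.float32, torch.float32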
tools/experimental/dynamic_shapes/torchfuzz/operators/arg.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""Arg operator implementation."""

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec


class ArgOperator(Operator):
    """Operator for function arguments/parameters."""

    def __init__(self):
        super().__init__("arg")

    def can_produce(self, output_spec: Spec) -> bool:
        """Arg can produce any type of output."""
        return True

    def supports_variable_inputs(self) -> bool:
        """Arg operator does not require inputs."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 0) -> list[Spec]:
        """Arg requires no inputs."""
        return []

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for arg operation."""
        # The actual argument name assignment will be handled separately
        # in the codegen.py when processing arg operations
        return f"# {output_name} will be assigned an argument value"
tools/experimental/dynamic_shapes/torchfuzz/operators/base.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Base operator implementation."""

from abc import ABC, abstractmethod

from torchfuzz.tensor_fuzzer import Spec


class Operator(ABC):
    """Base class for all operators in torchfuzz."""

    def __init__(self, name: str):
        """Initialize operator with name."""
        self.name = name

    @abstractmethod
    def can_produce(self, output_spec: Spec) -> bool:
        """Check if this operator can produce the given output spec."""

    @abstractmethod
    def supports_variable_inputs(self) -> bool:
        """Check if this operator supports variable number of inputs."""

    @abstractmethod
    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Decompose output spec into input specs."""

    @abstractmethod
    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for this operation."""

    def __str__(self) -> str:
        """String representation of the operator."""
        return f"{self.__class__.__name__}({self.name})"

    def __repr__(self) -> str:
        """Repr representation of the operator."""
        return self.__str__()
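
Any new op follows the same four-method contract. A hypothetical example against this base class (NegOperator is an illustration, not part of this commit):

from torchfuzz.operators.base import Operator
from torchfuzz.operators.registry import register_operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec


class NegOperator(Operator):
    """Element-wise negation (hypothetical)."""

    def __init__(self):
        super().__init__("torch.ops.aten.neg")

    def can_produce(self, output_spec: Spec) -> bool:
        return isinstance(output_spec, TensorSpec)

    def supports_variable_inputs(self) -> bool:
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
        return [output_spec]  # single input with the same layout and dtype

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        return f"{output_name} = torch.ops.aten.neg({input_names[0]})"


register_operator(NegOperator())  # fuzz_op can now pick it up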
tools/experimental/dynamic_shapes/torchfuzz/operators/constant.py (new file, 91 lines)
@@ -0,0 +1,91 @@
"""Constant operator implementation."""

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import (
    fuzz_scalar,
    fuzz_tensor_simple,
    ScalarSpec,
    Spec,
    TensorSpec,
)


class ConstantOperator(Operator):
    """Operator for generating constants."""

    def __init__(self):
        super().__init__("constant")

    def can_produce(self, output_spec: Spec) -> bool:
        """Constant can produce any type of output."""
        return True

    def supports_variable_inputs(self) -> bool:
        """Constant operator does not require inputs."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 0) -> list[Spec]:
        """Constant requires no inputs."""
        return []

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for constant creation."""
        # Create constant by calling fuzzing functions during codegen with deterministic seed
        # Use a deterministic seed based on the variable name to ensure reproducibility
        var_seed = hash(output_name) % (2**31)

        if isinstance(output_spec, ScalarSpec):
            # Call fuzz_scalar during codegen and embed the result
            actual_value = fuzz_scalar(output_spec, seed=var_seed)

            # Format the value for embedding in code
            if isinstance(actual_value, bool):
                value_str = str(actual_value)
            elif isinstance(actual_value, (int, float)):
                value_str = repr(actual_value)
            elif isinstance(actual_value, complex):
                value_str = f"complex({actual_value.real}, {actual_value.imag})"
            else:
                value_str = repr(actual_value)

            return f"{output_name} = {value_str}"

        elif isinstance(output_spec, TensorSpec):
            # Call fuzz_tensor_simple during codegen and embed the result
            actual_tensor = fuzz_tensor_simple(
                output_spec.size, output_spec.stride, output_spec.dtype, seed=var_seed
            )

            # Convert tensor to code representation
            size_str = str(output_spec.size)
            dtype_str = f"torch.{output_spec.dtype}".replace("torch.torch.", "torch.")

            # Handle empty tensors (with 0 elements)
            if actual_tensor.numel() == 0:
                # For empty tensors, use a default fill value based on dtype
                import torch

                default_values = {
                    torch.float16: 0.0,
                    torch.float32: 0.0,
                    torch.float64: 0.0,
                    torch.bfloat16: 0.0,
                    torch.int8: 0,
                    torch.int16: 0,
                    torch.int32: 0,
                    torch.int64: 0,
                    torch.bool: False,
                    torch.complex64: 0.0,
                    torch.complex128: 0.0,
                }
                fill_value = default_values.get(output_spec.dtype, 0)
                return f"{output_name} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"
            else:
                # For non-empty tensors, use the first element as fill value
                fill_value = actual_tensor.flatten()[0].item()
                return f"{output_name} = torch.full({size_str}, {fill_value}, dtype={dtype_str})"

        else:
            return f"# Unknown output spec type for constant: {type(output_spec)}"
tools/experimental/dynamic_shapes/torchfuzz/operators/item.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Item operator implementation."""

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec


class ItemOperator(Operator):
    """Operator for extracting a scalar from a tensor."""

    def __init__(self):
        super().__init__("torch.ops.aten.item")

    def can_produce(self, output_spec: Spec) -> bool:
        """Item can only produce scalars."""
        return isinstance(output_spec, ScalarSpec)

    def supports_variable_inputs(self) -> bool:
        """Item operator does not support variable number of inputs."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
        """Decompose scalar into a single-element tensor for item operation."""
        if not isinstance(output_spec, ScalarSpec):
            raise ValueError("ItemOperator can only produce ScalarSpec outputs")

        # Create a tensor spec that can produce a scalar via .item()
        # Use a 1-D tensor with 1 element
        tensor_spec = TensorSpec(size=(1,), stride=(1,), dtype=output_spec.dtype)

        return [tensor_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for item operation."""
        if len(input_names) != 1:
            raise ValueError("ItemOperator requires exactly one input")

        return f"{output_name} = {input_names[0]}.item()"
tools/experimental/dynamic_shapes/torchfuzz/operators/mul.py (new file, 72 lines)
@@ -0,0 +1,72 @@
"""Multiply operator implementation."""

import random

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec


class MulOperator(Operator):
    """Operator for element-wise multiplication."""

    def __init__(self):
        super().__init__("torch.ops.aten.mul")

    def can_produce(self, output_spec: Spec) -> bool:
        """Mul can produce tensors but not scalars."""
        return isinstance(output_spec, TensorSpec)

    def supports_variable_inputs(self) -> bool:
        """Mul operator supports variable number of inputs."""
        return True

    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Decompose tensor into input tensors for multiplication with type promotion."""
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("MulOperator can only produce TensorSpec outputs")

        # Use shared type promotion table
        from torchfuzz.type_promotion import (
            get_dtype_map,
            get_dtype_name,
            get_promotion_table_for_strings,
        )

        promotion_table = get_promotion_table_for_strings()

        # If num_inputs > 2, promote left-to-right (e.g. (((a * b) * c) * d))
        # For simplicity, we generate the first two with promotion, rest match output dtype
        dtype_str = get_dtype_name(output_spec.dtype)
        supported_types = promotion_table.get(dtype_str, [(dtype_str, dtype_str)])

        # Pick a random promotion pattern for the first two inputs
        if num_inputs >= 2:
            dtypes = list(random.choice(supported_types))
            # For >2 inputs, fill with output dtype
            while len(dtypes) < num_inputs:
                dtypes.append(dtype_str)
        else:
            dtypes = [dtype_str] * num_inputs

        # Convert dtype strings back to torch dtypes
        dtype_map = get_dtype_map()

        return [
            TensorSpec(
                size=output_spec.size,
                stride=output_spec.stride,
                dtype=dtype_map.get(dt, output_spec.dtype),
            )
            for dt in dtypes
        ]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for multiplication operation."""
        if len(input_names) == 2:
            return f"{output_name} = torch.ops.aten.mul({input_names[0]}, {input_names[1]})"
        else:
            # Multiply all input tensors
            expr = " * ".join(input_names)
            return f"{output_name} = {expr}"
tools/experimental/dynamic_shapes/torchfuzz/operators/registry.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Operator registry for mapping operation names to operator instances."""

from typing import Optional

from torchfuzz.operators.add import AddOperator
from torchfuzz.operators.arg import ArgOperator
from torchfuzz.operators.base import Operator
from torchfuzz.operators.constant import ConstantOperator
from torchfuzz.operators.item import ItemOperator
from torchfuzz.operators.mul import MulOperator
from torchfuzz.operators.scalar_add import ScalarAddOperator
from torchfuzz.operators.scalar_multiply import ScalarMultiplyOperator


class OperatorRegistry:
    """Registry for managing operator instances."""

    def __init__(self):
        """Initialize the registry with default operators."""
        self._operators: dict[str, Operator] = {}
        self._register_default_operators()

    def _register_default_operators(self):
        """Register the default set of operators."""
        self.register(AddOperator())
        self.register(MulOperator())
        self.register(ItemOperator())
        self.register(ScalarAddOperator())
        self.register(ScalarMultiplyOperator())
        self.register(ConstantOperator())
        self.register(ArgOperator())

    def register(self, operator: Operator):
        """Register an operator in the registry."""
        self._operators[operator.name] = operator

    def get(self, op_name: str) -> Optional[Operator]:
        """Get an operator by name."""
        return self._operators.get(op_name)

    def list_operators(self) -> dict[str, Operator]:
        """List all registered operators."""
        return self._operators.copy()


# Global registry instance
_global_registry = OperatorRegistry()


def get_operator(op_name: str) -> Optional[Operator]:
    """Get an operator from the global registry."""
    return _global_registry.get(op_name)


def register_operator(operator: Operator):
    """Register an operator in the global registry."""
    _global_registry.register(operator)


def list_operators() -> dict[str, Operator]:
    """List all operators in the global registry."""
    return _global_registry.list_operators()
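
Lookups are keyed by the name passed to Operator.__init__, and unknown names return None rather than raising (a sketch, assuming the package is importable):

from torchfuzz.operators.registry import get_operator, list_operators

assert get_operator("torch.ops.aten.mul") is not None
assert get_operator("not_an_op") is None
print(len(list_operators()))  # 7 default operators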
tools/experimental/dynamic_shapes/torchfuzz/operators/scalar_add.py (new file, 43 lines)
@@ -0,0 +1,43 @@
"""Scalar add operator implementation."""

import random

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec


class ScalarAddOperator(Operator):
    """Operator for adding two scalars."""

    def __init__(self):
        super().__init__("scalar_add")

    def can_produce(self, output_spec: Spec) -> bool:
        """Scalar add can only produce scalars."""
        return isinstance(output_spec, ScalarSpec)

    def supports_variable_inputs(self) -> bool:
        """Scalar add operator does not support variable number of inputs."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Decompose scalar into input scalars for addition with type promotion."""
        if not isinstance(output_spec, ScalarSpec):
            raise ValueError("ScalarAddOperator can only produce ScalarSpec outputs")

        # Use shared type promotion utility
        from torchfuzz.type_promotion import get_scalar_promotion_pairs

        supported_types = get_scalar_promotion_pairs(output_spec.dtype)
        dtypes = random.choice(supported_types)

        return [ScalarSpec(dtype=dtypes[0]), ScalarSpec(dtype=dtypes[1])]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for scalar addition operation."""
        if len(input_names) != 2:
            raise ValueError("ScalarAddOperator requires exactly two inputs")

        return f"{output_name} = {input_names[0]} + {input_names[1]}"
tools/experimental/dynamic_shapes/torchfuzz/operators/scalar_multiply.py (new file, 45 lines)
@@ -0,0 +1,45 @@
"""Scalar multiply operator implementation."""

import random

from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec


class ScalarMultiplyOperator(Operator):
    """Operator for multiplying two scalars."""

    def __init__(self):
        super().__init__("scalar_multiply")

    def can_produce(self, output_spec: Spec) -> bool:
        """Scalar multiply can only produce scalars."""
        return isinstance(output_spec, ScalarSpec)

    def supports_variable_inputs(self) -> bool:
        """Scalar multiply operator does not support variable number of inputs."""
        return False

    def decompose(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Decompose scalar into input scalars for multiplication with type promotion."""
        if not isinstance(output_spec, ScalarSpec):
            raise ValueError(
                "ScalarMultiplyOperator can only produce ScalarSpec outputs"
            )

        # Use shared type promotion utility
        from torchfuzz.type_promotion import get_scalar_promotion_pairs

        supported_types = get_scalar_promotion_pairs(output_spec.dtype)
        dtypes = random.choice(supported_types)

        return [ScalarSpec(dtype=dtypes[0]), ScalarSpec(dtype=dtypes[1])]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for scalar multiplication operation."""
        if len(input_names) != 2:
            raise ValueError("ScalarMultiplyOperator requires exactly two inputs")

        return f"{output_name} = {input_names[0]} * {input_names[1]}"
tools/experimental/dynamic_shapes/torchfuzz/ops_fuzzer.py
@@ -4,7 +4,10 @@ import random
 from dataclasses import dataclass
 from typing import Optional
 
-from tensor_fuzzer import (
+import torch
+
+from torchfuzz.operators import get_operator, list_operators
+from torchfuzz.tensor_fuzzer import (
     fuzz_tensor_size,
     fuzz_torch_tensor_type,
     fuzz_valid_stride,
@@ -14,8 +17,6 @@ from tensor_fuzzer import (
     TensorSpec,
 )
 
-import torch
-
 
 @dataclass
 class OperationNode:
@@ -172,11 +173,7 @@ def fuzz_spec() -> Spec:
 def fuzz_op(target_spec: Spec, depth, stack_size) -> tuple[str, list[Spec]]:
     """
     Given an output specification, returns an operation that can
-    produce a tensor with that layout.
-
-    Supports:
-    - For scalars: scalar_add, scalar_multiply, item, constant, arg
-    - For tensors: aten.add, aten.mul, constant, arg
+    produce a tensor with that layout using the operator class system.
 
     Args:
         target_spec: Desired output specification (TensorSpec or ScalarSpec)
@@ -188,239 +185,69 @@ def fuzz_op(target_spec: Spec, depth, stack_size) -> tuple[str, list[Spec]]:
         Tuple of (operation_name, list_of_argument_specs) where each argument spec
         describes the layout requirements for the operation's inputs
     """
-    if isinstance(target_spec, ScalarSpec):
-        if target_spec.constant is not None:
-            # At depth 0, only allow constant operation
-            return _get_constant_args_specs(target_spec)
-        if depth == 0:
-            # At depth 0, only allow leaf operations
-            ops = ["constant", "arg"]
-            chosen_op = random.choice(ops)
-        else:
-            # At higher depths, allow all scalar operations
-            non_leaf_ops = ["scalar_add", "scalar_multiply", "torch.ops.aten.item"]
-            leaf_ops = ["constant", "arg"]
-
-            # Reduce probability of leaf operations when stack_size < 10
-            if stack_size < 10 or depth > 7:
-                # 80% chance of non-leaf, 20% chance of leaf
-                if random.random() < 0.8:
-                    chosen_op = random.choice(non_leaf_ops)
-                else:
-                    chosen_op = random.choice(leaf_ops)
-            else:
-                # Normal probability distribution
-                all_ops = non_leaf_ops + leaf_ops
-                chosen_op = random.choice(all_ops)
-
-        if chosen_op == "scalar_add":
-            return _get_scalar_add_args_specs(target_spec)
-        elif chosen_op == "scalar_multiply":
-            return _get_scalar_multiply_args_specs(target_spec)
-        elif chosen_op == "torch.ops.aten.item":
-            return _get_item_args_specs(target_spec)
-        elif chosen_op == "constant":
-            return _get_constant_args_specs(target_spec)
-        else:  # arg
-            return _get_arg_args_specs(target_spec)
-
-    elif isinstance(target_spec, TensorSpec):
-        if depth == 0:
-            # At depth 0, only allow leaf operations
-            ops = ["arg"]
-            chosen_op = random.choice(ops)
-        else:
-            # At higher depths, allow all tensor operations
-            non_leaf_ops = [
-                "torch.ops.aten.add",
-                "torch.ops.aten.mul",
-            ]
-
-            leaf_ops = ["arg"]
-
-            # Reduce probability of leaf operations when stack_size < 10
-            if stack_size < 10:
-                # 80% chance of non-leaf, 20% chance of leaf
-                if random.random() < 0.8:
-                    chosen_op = random.choice(non_leaf_ops)
-                else:
-                    chosen_op = random.choice(leaf_ops)
-            else:
-                # Normal probability distribution
-                all_ops = non_leaf_ops + leaf_ops
-                chosen_op = random.choice(all_ops)
-
-        if chosen_op == "torch.ops.aten.add":
-            return _get_aten_add_args_specs(target_spec)
-        elif chosen_op == "torch.ops.aten.mul":
-            return _get_aten_mul_args_specs(target_spec)
-        elif chosen_op == "constant":
-            return _get_constant_args_specs(target_spec)
-        else:  # arg
-            return _get_arg_args_specs(target_spec)
-
-    else:
-        raise ValueError(f"Unknown target spec type: {type(target_spec)}")
-
-
-def _get_scalar_add_args_specs(target_spec: ScalarSpec) -> tuple[str, list[Spec]]:
-    """Get argument specifications for scalar_add operation using type promotion rules."""
-    # Use PyTorch's implicit type promotion rules to generate diverse input types
-    arg_specs = _get_promoted_scalar_args(target_spec.dtype)
-    return "scalar_add", arg_specs
-
-
-def _get_scalar_multiply_args_specs(target_spec: ScalarSpec) -> tuple[str, list[Spec]]:
-    """Get argument specifications for scalar_multiply operation using type promotion rules."""
-    # Use PyTorch's implicit type promotion rules to generate diverse input types
-    arg_specs = _get_promoted_scalar_args(target_spec.dtype)
-    return "scalar_multiply", arg_specs
-
-
-# Define promotion chains - types that can promote to the target
-# PyTorch promotion hierarchy (simplified):
-# - bool < int8 < int16 < int32 < int64 < float16 < float32 < float64 < complex64 < complex128
-# - uint types have limited promotion support
-_PROMOTION_CHAINS = {
-    torch.bool: [torch.bool],
-    torch.int8: [torch.bool, torch.int8],
-    torch.int16: [torch.bool, torch.int8, torch.int16],
-    torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32],
-    torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64],
-    torch.float16: [
-        torch.bool,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-        torch.float16,
-    ],
-    torch.float32: [
-        torch.bool,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-        torch.float16,
-        torch.float32,
-    ],
-    torch.float64: [
-        torch.bool,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-        torch.float16,
-        torch.float32,
-        torch.float64,
-    ],
-    torch.complex64: [
-        torch.bool,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-        torch.float16,
-        torch.float32,
-        torch.complex64,
-    ],
-    torch.complex128: [
-        torch.bool,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-        torch.float16,
-        torch.float32,
-        torch.float64,
-        torch.complex64,
-        torch.complex128,
-    ],
-}
-
-
-def _get_promoted_dtypes(target_dtype: torch.dtype) -> list[torch.dtype]:
-    """
-    Generate two dtypes that will promote to target_dtype via PyTorch's type promotion rules.
-    """
-
-    # Get compatible input types for the target dtype
-    compatible_types = _PROMOTION_CHAINS.get(target_dtype, [target_dtype])
-
-    # Strategy: Choose between same type or mixed promotion
-    strategies = ["same_type", "mixed_promotion"]
-    strategy = random.choice(strategies)
-
-    if strategy == "same_type":
-        # Both args same type as target
-        return [target_dtype, target_dtype]
-
-    else:  # mixed_promotion
-        # Mixed types where the result will promote to target_dtype
-        lower_types = compatible_types[:-1]  # All except the last (target_dtype)
-
-        if lower_types:
-            # One arg is target_dtype, one is lower (will promote to target)
-            lower_dtype = random.choice(lower_types)
-            if random.random() < 0.5:
-                return [target_dtype, lower_dtype]
-            else:
-                return [lower_dtype, target_dtype]
-        else:
-            # Fallback to same type if no lower types available
-            return [target_dtype, target_dtype]
-
-
-def _get_promoted_scalar_args(target_dtype: torch.dtype) -> list[Spec]:
-    """
-    Generate two argument specs that will promote to target_dtype via PyTorch's type promotion rules.
-    """
-    arg_dtypes = _get_promoted_dtypes(target_dtype)
-
-    # For ScalarSpec output, both inputs must be ScalarSpec
-    # (mixing with 0-D TensorSpec would produce 0-D TensorSpec output)
-    return [ScalarSpec(arg_dtypes[0]), ScalarSpec(arg_dtypes[1])]
-
-
-def _get_item_args_specs(target_spec: ScalarSpec) -> tuple[str, list[Spec]]:
-    """Get argument specifications for torch.ops.aten.item operation."""
-    # torch.ops.aten.item: tensor -> scalar (extract single element)
-    # Create a tensor spec that can produce a scalar via .item()
-    tensor_spec = TensorSpec(
-        size=(1,), stride=(1,), dtype=target_spec.dtype
-    )  # 1-D tensor with 1 element
-    arg_specs: list[Spec] = [tensor_spec]
-    return "torch.ops.aten.item", arg_specs
-
-
-def _get_aten_add_args_specs(target_spec: TensorSpec) -> tuple[str, list[Spec]]:
-    """Get argument specifications for torch.ops.aten.add operation using type promotion rules."""
-    # Use promotion rules to generate diverse tensor input types
-    arg_dtypes = _get_promoted_dtypes(target_spec.dtype)
-
-    arg_specs: list[Spec] = [
-        TensorSpec(target_spec.size, target_spec.stride, arg_dtypes[0]),
-        TensorSpec(target_spec.size, target_spec.stride, arg_dtypes[1]),
-    ]
-    return "torch.ops.aten.add", arg_specs
-
-
-def _get_aten_mul_args_specs(target_spec: TensorSpec) -> tuple[str, list[Spec]]:
-    """Get argument specifications for torch.ops.aten.mul operation using type promotion rules."""
-    # Use promotion rules to generate diverse tensor input types
-    arg_dtypes = _get_promoted_dtypes(target_spec.dtype)
-
-    arg_specs: list[Spec] = [
-        TensorSpec(target_spec.size, target_spec.stride, arg_dtypes[0]),
-        TensorSpec(target_spec.size, target_spec.stride, arg_dtypes[1]),
-    ]
-    return "torch.ops.aten.mul", arg_specs
-
-
-def _get_constant_args_specs(target_spec: Spec) -> tuple[str, list[Spec]]:
-    """Get argument specifications for constant operation."""
-    # Constant operation takes no arguments - generates a fixed constant value/tensor
-    return "constant", []
+    # Get all available operators
+    available_operators = list_operators()
+
+    # Filter operators that can produce the target spec
+    compatible_ops = []
+    for op_name, operator in available_operators.items():
+        if operator.can_produce(target_spec):
+            compatible_ops.append((op_name, operator))
+
+    if not compatible_ops:
+        raise ValueError(f"No operators available that can produce {target_spec}")
+
+    # Categorize operators into leaf and non-leaf
+    leaf_ops = []
+    non_leaf_ops = []
+
+    for op_name, operator in compatible_ops:
+        if op_name in ["constant", "arg"] or op_name.startswith("arg_"):
+            leaf_ops.append((op_name, operator))
+        else:
+            non_leaf_ops.append((op_name, operator))
+
+    # Choose operation based on depth and stack size constraints
+    if depth == 0:
+        # At depth 0, only allow leaf operations
+        if not leaf_ops:
+            # If no leaf ops can produce this spec, fallback to arg
+            return _get_arg_args_specs(target_spec)
+        chosen_op_name, chosen_operator = random.choice(leaf_ops)
+    else:
+        # At higher depths, choose between leaf and non-leaf operations
+        # Reduce probability of leaf operations when stack_size < 10
+        if (stack_size < 10 or depth > 7) and non_leaf_ops:
+            # 80% chance of non-leaf, 20% chance of leaf
+            if random.random() < 0.8:
+                chosen_op_name, chosen_operator = random.choice(non_leaf_ops)
+            else:
+                chosen_op_name, chosen_operator = (
+                    random.choice(leaf_ops) if leaf_ops else random.choice(non_leaf_ops)
+                )
+        else:
+            # Normal probability distribution
+            all_ops = non_leaf_ops + leaf_ops
+            chosen_op_name, chosen_operator = (
+                random.choice(all_ops) if all_ops else ("arg", get_operator("arg"))
+            )
+
+    # Use the operator to decompose the target spec into input specs
+    try:
+        if chosen_op_name.startswith("arg_"):
+            # Handle special arg_ operations
+            return chosen_op_name, []
+        elif chosen_op_name in ["constant", "arg"]:
+            # Handle leaf operations
+            return chosen_op_name, []
+        else:
+            # Use the operator's decompose method
+            input_specs = chosen_operator.decompose(target_spec)
+            return chosen_op_name, input_specs
+    except Exception as e:
+        # Fallback to arg if decomposition fails
+        print(f"Warning: operator {chosen_op_name} decomposition failed: {e}")
+        return _get_arg_args_specs(target_spec)
 
 
 # Global counter for generating unique argument IDs
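A sketch of the rewritten fuzz_op in use: the returned name is whichever compatible operator the registry offered, and the specs come from that operator's decompose (assumes the package is importable; the output is random):

import torch

from torchfuzz.ops_fuzzer import fuzz_op
from torchfuzz.tensor_fuzzer import ScalarSpec

op_name, input_specs = fuzz_op(ScalarSpec(dtype=torch.float32), depth=3, stack_size=5)
print(op_name, input_specs)
# e.g. scalar_add [ScalarSpec(dtype=torch.float16), ScalarSpec(dtype=torch.float32)]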
tools/experimental/dynamic_shapes/torchfuzz/type_promotion.py (new file, 194 lines)
@@ -0,0 +1,194 @@
"""Type promotion utilities for torchfuzz operators."""

import random

import torch


# Define promotion chains - types that can promote to the target
# PyTorch promotion hierarchy (simplified):
# - bool < int8 < int16 < int32 < int64 < float16 < float32 < float64 < complex64 < complex128
# - uint types have limited promotion support
PROMOTION_CHAINS = {
    torch.bool: [torch.bool],
    torch.int8: [torch.bool, torch.int8],
    torch.int16: [torch.bool, torch.int8, torch.int16],
    torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32],
    torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64],
    torch.float16: [
        torch.bool,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
    ],
    torch.float32: [
        torch.bool,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
    ],
    torch.float64: [
        torch.bool,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
    ],
    torch.complex64: [
        torch.bool,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.complex64,
    ],
    torch.complex128: [
        torch.bool,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.complex64,
        torch.complex128,
    ],
}


def get_promoted_dtypes(target_dtype: torch.dtype) -> list[torch.dtype]:
    """
    Generate two dtypes that will promote to target_dtype via PyTorch's type promotion rules.
    """
    # Get compatible input types for the target dtype
    compatible_types = PROMOTION_CHAINS.get(target_dtype, [target_dtype])

    # Strategy: Choose between same type or mixed promotion
    strategies = ["same_type", "mixed_promotion"]
    strategy = random.choice(strategies)

    if strategy == "same_type":
        # Both args same type as target
        return [target_dtype, target_dtype]

    else:  # mixed_promotion
        # Mixed types where the result will promote to target_dtype
        lower_types = compatible_types[:-1]  # All except the last (target_dtype)

        if lower_types:
            # One arg is target_dtype, one is lower (will promote to target)
            lower_dtype = random.choice(lower_types)
            if random.random() < 0.5:
                return [target_dtype, lower_dtype]
            else:
                return [lower_dtype, target_dtype]
        else:
            # Fallback to same type if no lower types available
            return [target_dtype, target_dtype]


def get_dtype_name(dtype: torch.dtype) -> str:
    """Get string name for a torch dtype."""
    return str(dtype).split(".")[-1]


def get_promotion_table_for_strings() -> dict:
    """
    Get promotion table using string dtype names for backward compatibility.
    Returns dictionary mapping output dtype string to possible input dtype string pairs.
    """
    return {
        "float32": [
            ("float32", "float32"),
            ("bfloat16", "float32"),
            ("float32", "bfloat16"),
            ("float16", "float32"),
            ("float32", "float16"),
        ],
        "bfloat16": [
            ("bfloat16", "bfloat16"),
            ("float32", "bfloat16"),
            ("bfloat16", "float32"),
        ],
        "float16": [
            ("float16", "float16"),
            ("float32", "float16"),
            ("float16", "float32"),
        ],
        "int32": [
            ("int32", "int32"),
            ("int64", "int32"),
            ("int32", "int64"),
        ],
        "int64": [
            ("int64", "int64"),
            ("int32", "int64"),
            ("int64", "int32"),
        ],
    }


def get_dtype_map() -> dict:
    """Get mapping from string names to torch dtypes."""
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "int32": torch.int32,
        "int64": torch.int64,
        "bool": torch.bool,
        "int8": torch.int8,
        "int16": torch.int16,
        "float64": torch.float64,
        "complex64": torch.complex64,
        "complex128": torch.complex128,
    }


def get_scalar_promotion_pairs(
    target_dtype: torch.dtype,
) -> list[tuple[torch.dtype, torch.dtype]]:
    """
    Get promotion pairs for scalar operations.
    Returns list of (dtype1, dtype2) tuples that promote to target_dtype.
    """
    return (
        [
            (torch.float32, torch.float32),
            (torch.float16, torch.float32),
            (torch.float32, torch.float16),
            (torch.int32, torch.float32),
            (torch.float32, torch.int32),
        ]
        if target_dtype == torch.float32
        else [
            (torch.float64, torch.float64),
            (torch.float32, torch.float64),
            (torch.float64, torch.float32),
        ]
        if target_dtype == torch.float64
        else [
            (torch.int32, torch.int32),
            (torch.int64, torch.int32),
            (torch.int32, torch.int64),
        ]
        if target_dtype == torch.int32
        else [
            (torch.int64, torch.int64),
            (torch.int32, torch.int64),
            (torch.int64, torch.int32),
        ]
        if target_dtype == torch.int64
        else [(target_dtype, target_dtype)]
    )
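
A sketch of these helpers in use (assumes the package is importable; the promoted pair is randomly chosen):

import torch

from torchfuzz.type_promotion import (
    get_dtype_name,
    get_promoted_dtypes,
    get_scalar_promotion_pairs,
)

print(get_dtype_name(torch.float32))            # float32
print(get_promoted_dtypes(torch.float32))       # e.g. [torch.float32, torch.int64]
print(get_scalar_promotion_pairs(torch.int64))  # three int64/int32 pairs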