[torchfuzz] add support for operator weights (#164649)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164649
Approved by: https://github.com/pianpwk
ghstack dependencies: #164432, #164434, #164514, #164646, #164647
bobrenjc93
2025-10-05 23:01:58 -07:00
committed by PyTorch MergeBot
parent ded099ecbf
commit 5fe7f29b9e
10 changed files with 296 additions and 29 deletions

View File

@ -123,6 +123,28 @@ python fuzzer.py --seed 42 --log-level DEBUG --max-depth 5
| `--verbose` | Print detailed output for all runs | `--verbose` |
| `--template NAME` | Template to use for all runs | `--template default` |
## Restricting and weighting supported ops
You can restrict the fuzzer to a specific set of fully-qualified torch ops and optionally weight them to bias sampling.
- Restrict to only torch.add and torch.matmul (equal likelihood):
```bash
python fuzzer.py --seed 42 \
--supported-ops "torch.add,torch.matmul"
```
- Restrict to only torch.add and torch.matmul, and make matmul 5x more likely than add:
```bash
python fuzzer.py --seed 42 \
--supported-ops "torch.add,torch.matmul=5"
```
Notes:
- Use fully-qualified torch op names (e.g., torch.matmul, torch.nn.functional.rms_norm).
- Weights must be > 0.
- If both --supported-ops and --op-weights specify a weight for the same op, the value from --supported-ops takes precedence.
- Weights can also be set programmatically; see the sketch below.
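A minimal sketch of programmatic weighting, assuming `torchfuzz` is importable (`set_operator_weights` is one of the registry helpers added in this PR):
```python
from torchfuzz.operators import set_operator_weights

# Keys may be registry names ("add") or fully-qualified torch op
# names ("torch.matmul"); weights must be > 0.
set_operator_weights({"torch.matmul": 5.0, "torch.add": 1.0})
```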
## Architecture
### Core Components

View File

@ -20,11 +20,42 @@ from torchfuzz.runner import ProgramRunner
from torchfuzz.visualize_graph import visualize_operation_graph
def _parse_supported_ops_with_weights(spec: str) -> tuple[list[str], dict[str, float]]:
"""Parse --supported-ops string.
Format: comma-separated fully-qualified torch ops, each optionally with =weight.
Example: "torch.matmul=5,torch.nn.functional.rms_norm=5,torch.add"
Returns (ops_list, weights_dict)
"""
ops: list[str] = []
weights: dict[str, float] = {}
if not spec:
return ops, weights
for entry in spec.split(","):
entry = entry.strip()
if not entry:
continue
if "=" in entry:
name, w = entry.split("=", 1)
name = name.strip()
try:
weight = float(w.strip())
except ValueError:
continue
ops.append(name)
weights[name] = weight
else:
ops.append(entry)
return ops, weights
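# Illustrative behavior (hypothetical input):
#   _parse_supported_ops_with_weights("torch.add,torch.matmul=5")
#   -> (["torch.add", "torch.matmul"], {"torch.matmul": 5.0})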
def fuzz_and_execute(
seed: Optional[int] = None,
max_depth: Optional[int] = None,
log_at_failure: bool = False,
template: str = "default",
supported_ops: Optional[list[str]] = None,
op_weights: Optional[dict[str, float]] = None,
) -> None:
"""
Generate a fuzzed operation stack, convert it to Python code, and execute it.
@ -113,6 +144,12 @@ def fuzz_and_execute(
logger.debug("⏱️ Step 1: Generating target spec...")
start_time = time.time()
target_spec = fuzz_spec(template)
# Apply user-specified operator weights (if provided)
if op_weights:
from torchfuzz.operators import set_operator_weights
set_operator_weights(op_weights)
logger.debug(
" Completed in %.3fs - %s", time.time() - start_time, target_spec
)
@ -120,7 +157,11 @@ def fuzz_and_execute(
logger.debug("⏱️ Step 2: Generating operation graph...")
start_time = time.time()
operation_graph = fuzz_operation_graph(
target_spec, max_depth=max_depth, seed=seed, template=template
target_spec,
max_depth=max_depth,
seed=seed,
template=template,
supported_ops=supported_ops,
)
# Extract and print operation statistics
@ -226,6 +267,15 @@ if __name__ == "__main__":
default="default",
help="Template to use for code generation (default: default)",
)
parser.add_argument(
"--supported-ops",
type=str,
help=(
"Comma-separated fully-qualified torch ops to allow, each optionally with =weight. "
"Examples: 'torch.matmul,torch.nn.functional.rms_norm' or "
"'torch.matmul=5,torch.nn.functional.rms_norm=5'. Overrides template supported ops."
),
)
# Multi-process fuzzing arguments
parser.add_argument(
@ -272,8 +322,20 @@ if __name__ == "__main__":
if args.seed is not None or args.single:
# Single seed execution mode
print("Running single fuzz_and_execute...")
# Parse supported ops and any inline weights from that flag
parsed_supported_ops: Optional[list[str]] = None
parsed_weights: dict[str, float] = {}
if args.supported_ops:
parsed_supported_ops, parsed_weights = _parse_supported_ops_with_weights(
args.supported_ops
)
fuzz_and_execute(
seed=args.seed, max_depth=args.max_depth, template=args.template
seed=args.seed,
max_depth=args.max_depth,
template=args.template,
supported_ops=parsed_supported_ops,
op_weights=(parsed_weights if parsed_weights else None),
)
elif args.start is not None or args.count is not None:
# Multi-process fuzzing mode
@ -305,6 +367,7 @@ if __name__ == "__main__":
seed_count=args.count,
verbose=args.verbose,
template=args.template,
supported_ops=args.supported_ops,
)
except Exception as e:
print(f"❌ Unexpected error: {str(e)}")

View File

@ -90,13 +90,18 @@ def is_ignored_output(output: str) -> int:
return -1
def run_fuzzer_with_seed(seed: int, template: str = "default") -> FuzzerResult:
def run_fuzzer_with_seed(
seed: int,
template: str = "default",
supported_ops: Optional[str] = None,
) -> FuzzerResult:
"""
Run fuzzer.py with a specific seed.
Args:
seed: The seed value to pass to fuzzer.py
template: The template to use for code generation
supported_ops: Comma-separated ops string with optional weights
Returns:
FuzzerResult dataclass instance
@ -115,6 +120,10 @@ def run_fuzzer_with_seed(seed: int, template: str = "default") -> FuzzerResult:
template,
]
# Append supported ops if provided
if supported_ops:
cmd.extend(["--supported-ops", supported_ops])
result = subprocess.run(
cmd,
capture_output=True,
@ -213,6 +222,7 @@ def run_multi_process_fuzzer(
seed_count: int = 100,
verbose: bool = False,
template: str = "default",
supported_ops: Optional[str] = None,
) -> None:
"""
Run the multi-process fuzzer.
@ -222,6 +232,8 @@ def run_multi_process_fuzzer(
seed_start: Starting seed value (inclusive)
seed_count: Number of seeds to run
verbose: Whether to print detailed output
template: The template to use for code generation
supported_ops: Comma-separated ops string with optional weights
"""
seeds = list(range(seed_start, seed_start + seed_count))
@ -250,7 +262,9 @@ def run_multi_process_fuzzer(
# Submit all seeds to the process pool
future_results = []
for seed in seeds:
future = pool.apply_async(run_fuzzer_with_seed, (seed, template))
future = pool.apply_async(
run_fuzzer_with_seed, (seed, template, supported_ops)
)
future_results.append(future)
# Set up progress bar

View File

@ -25,7 +25,15 @@ from torchfuzz.operators.nn_functional import (
ReLUOperator,
SoftmaxOperator,
)
from torchfuzz.operators.registry import get_operator, list_operators, register_operator
from torchfuzz.operators.registry import (
get_operator,
list_operators,
register_operator,
set_operator_weight,
set_operator_weight_by_torch_op,
set_operator_weights,
set_operator_weights_by_torch_op,
)
from torchfuzz.operators.scalar_pointwise import (
ScalarAddOperator,
ScalarDivOperator,
@ -75,4 +83,8 @@ __all__ = [
"get_operator",
"register_operator",
"list_operators",
"set_operator_weight",
"set_operator_weights",
"set_operator_weight_by_torch_op",
"set_operator_weights_by_torch_op",
]

View File

@ -9,9 +9,16 @@ from torchfuzz.tensor_fuzzer import Spec
class Operator(ABC):
"""Base class for all operators in torchfuzz."""
def __init__(self, name: str):
"""Initialize operator with name."""
def __init__(self, name: str, weight: float = 1.0):
"""Initialize operator with name and optional selection weight.
Args:
name: Unique operator name used in the registry
weight: Relative selection weight when sampling among compatible operators
(default 1.0). Higher values increase selection likelihood.
"""
self.name = name
self.weight: float = float(weight)
@property
@abstractmethod
@ -32,10 +39,13 @@ class Operator(ABC):
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""
Get input specifications for fuzzing. By default, delegates to decompose.
Leaf operators should override this to return an empty list.
Get input specifications for fuzzing.
Subclasses must implement this to return a list of input Specs that,
when used with this operator, can produce the given output_spec. Leaf
operators should return an empty list.
"""
return self.decompose(output_spec)
raise NotImplementedError("Subclasses must implement fuzz_inputs_specs()")
@abstractmethod
def codegen(
@ -44,6 +54,22 @@ class Operator(ABC):
"""Generate code for this operation."""
raise NotImplementedError("Subclasses must implement codegen()")
def get_weight(
self,
*,
target_spec: Optional[Spec] = None,
depth: Optional[int] = None,
stack_size: Optional[int] = None,
template: Optional[str] = None,
) -> float:
"""
Return the selection weight for this operator.
Subclasses may override to implement context-sensitive weighting.
The default implementation returns the static attribute `self.weight`.
"""
return self.weight
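# Illustrative override (not part of this diff): a subclass could bias
# selection toward shallow depths:
#   def get_weight(self, *, depth=None, **kwargs):
#       return self.weight * (2.0 if depth is not None and depth < 3 else 1.0)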
def __str__(self) -> str:
"""String representation of the operator."""
return f"{self.__class__.__name__}({self.name})"

View File

@ -61,6 +61,7 @@ class MMOperator(MatrixMultiplyOperator):
def __init__(self):
super().__init__("mm", "torch.mm")
self.weight = 5.0
def can_produce(self, output_spec: Spec) -> bool:
"""MM requires exactly 2D tensors."""
@ -141,6 +142,7 @@ class AddmmOperator(MatrixMultiplyOperator):
def __init__(self):
super().__init__("addmm", "torch.addmm")
self.weight = 5.0
def can_produce(self, output_spec: Spec) -> bool:
"""Addmm requires exactly 2D tensors."""
@ -229,6 +231,7 @@ class BmmOperator(MatrixMultiplyOperator):
def __init__(self):
super().__init__("bmm", "torch.bmm")
self.weight = 5.0
def can_produce(self, output_spec: Spec) -> bool:
"""Batch matrix multiply requires 3D tensors."""
@ -309,6 +312,7 @@ class MatmulOperator(MatrixMultiplyOperator):
def __init__(self):
super().__init__("matmul", "torch.matmul")
self.weight = 500.0
def can_produce(self, output_spec: Spec) -> bool:
"""Matmul can handle various tensor dimensions >= 1."""

View File

@ -434,6 +434,7 @@ class RMSNormOperator(Operator):
def __init__(self):
super().__init__("torch.nn.functional.rms_norm")
self.weight = 5.0
@property
def torch_op_name(self) -> Optional[str]:

View File

@ -154,3 +154,52 @@ def register_operator(operator: Operator):
def list_operators() -> dict[str, Operator]:
"""List all operators in the global registry."""
return _global_registry.list_operators()
def set_operator_weight(op_name: str, weight: float) -> None:
"""Set the selection weight for a specific operator.
Args:
op_name: The registered operator name (e.g., "add", "arg") OR fully-qualified torch op
(e.g., "torch.nn.functional.relu", "torch.matmul")
weight: New relative selection weight (must be > 0)
"""
if weight <= 0:
raise ValueError("Operator weight must be > 0")
# Try by registry key
op = _global_registry.get(op_name)
if op is not None:
op.weight = float(weight)
return
# Fallback: try to locate by fully-qualified torch op name
for candidate in _global_registry.list_operators().values():
if getattr(candidate, "torch_op_name", None) == op_name:
candidate.weight = float(weight)
return
raise KeyError(f"Operator '{op_name}' not found by registry name or torch op name")
def set_operator_weights(weights: dict[str, float]) -> None:
"""Bulk-update operator weights from a mapping of name -> weight."""
for name, w in weights.items():
set_operator_weight(name, w)
def set_operator_weight_by_torch_op(torch_op_name: str, weight: float) -> None:
"""Set operator weight by fully-qualified torch op name."""
if weight <= 0:
raise ValueError("Operator weight must be > 0")
for candidate in _global_registry.list_operators().values():
if getattr(candidate, "torch_op_name", None) == torch_op_name:
candidate.weight = float(weight)
return
raise KeyError(f"Torch op '{torch_op_name}' not found in registry")
def set_operator_weights_by_torch_op(weights: dict[str, float]) -> None:
"""Bulk-update weights by fully-qualified torch op names."""
for name, w in weights.items():
set_operator_weight_by_torch_op(name, w)
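# Illustrative usage of the helpers above:
#   set_operator_weight("add", 2.0)            # by registry key
#   set_operator_weight("torch.matmul", 10.0)  # falls back to torch_op_name lookup
#   set_operator_weights_by_torch_op({"torch.mm": 5.0, "torch.bmm": 5.0})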

View File

@ -86,8 +86,9 @@ class PointwiseOperator(Operator):
class AddOperator(PointwiseOperator):
"""Operator for element-wise addition."""
def __init__(self):
def __init__(self, weight: float = 1.0):
super().__init__("add", "torch.add", "+")
self.weight = float(weight)
class MulOperator(PointwiseOperator):

View File

@ -30,8 +30,15 @@ def _get_cached_operators():
return _CACHED_OPERATORS
def _get_template_filtered_operators(template: str = "default"):
"""Get operators filtered by template's supported_ops."""
def _get_template_filtered_operators(
template: str = "default", supported_ops: Optional[list[str]] = None
):
"""Get operators filtered by template's supported_ops, with user override.
If supported_ops is provided, it takes precedence and is used to filter the
registry. Otherwise, the template's supported_ops are used. If neither is
specified, all operators are returned.
"""
# Instantiate template
if template == "dtensor":
from torchfuzz.codegen import DTensorFuzzTemplate
@ -48,11 +55,14 @@ def _get_template_filtered_operators(template: str = "default"):
all_operators = _get_cached_operators()
# Determine allowed ops list
allowed_ops = supported_ops if supported_ops else fuzz_template.supported_ops
# If no supported_ops specified, return all operators
if not fuzz_template.supported_ops:
if not allowed_ops:
return all_operators
# Filter operators based on supported_ops
# Filter operators based on allowed_ops
filtered_ops = {}
for op_name, operator in all_operators.items():
@ -66,9 +76,9 @@ def _get_template_filtered_operators(template: str = "default"):
filtered_ops[op_name] = operator
continue
# Check if the operator supports any of the template's operations
# Check if the operator supports any of the allowed operations
should_include = False
for supported_op in fuzz_template.supported_ops:
for supported_op in allowed_ops:
# Direct torch operation matching
if torch_op == supported_op:
should_include = True
@ -260,7 +270,11 @@ def fuzz_spec(template: str = "default") -> Spec:
def fuzz_op(
target_spec: Spec, depth, stack_size, template: str = "default"
target_spec: Spec,
depth,
stack_size,
template: str = "default",
supported_ops: Optional[list[str]] = None,
) -> tuple[str, list[Spec]]:
"""
Given an output specification, returns an operation that can
@ -277,7 +291,7 @@ def fuzz_op(
describes the layout requirements for the operation's inputs
"""
# Get template-filtered operators
available_operators = _get_template_filtered_operators(template)
available_operators = _get_template_filtered_operators(template, supported_ops)
# Filter operators that can produce the target spec
compatible_ops = []
@ -306,24 +320,84 @@ def fuzz_op(
if not leaf_ops:
# If no leaf ops can produce this spec, fallback to arg
return _get_arg_args_specs(target_spec)
chosen_op_name, chosen_operator = random.choice(leaf_ops)
# Weighted choice among leaf ops
leaf_weights = [
op.get_weight(
target_spec=target_spec,
depth=depth,
stack_size=stack_size,
template=template,
)
for _, op in leaf_ops
]
idx = random.choices(range(len(leaf_ops)), weights=leaf_weights, k=1)[0]
chosen_op_name, chosen_operator = leaf_ops[idx]
else:
# At higher depths, choose between leaf and non-leaf operations
# Reduce probability of leaf operations when stack_size < 10
if (stack_size < 10 or depth > 7) and non_leaf_ops:
# 80% chance of non-leaf, 20% chance of leaf
if random.random() < 0.8:
chosen_op_name, chosen_operator = random.choice(non_leaf_ops)
# Weighted choice among non-leaf ops
nonleaf_weights = [
op.get_weight(
target_spec=target_spec,
depth=depth,
stack_size=stack_size,
template=template,
)
for _, op in non_leaf_ops
]
idx = random.choices(
range(len(non_leaf_ops)), weights=nonleaf_weights, k=1
)[0]
chosen_op_name, chosen_operator = non_leaf_ops[idx]
else:
chosen_op_name, chosen_operator = (
random.choice(leaf_ops) if leaf_ops else random.choice(non_leaf_ops)
)
if leaf_ops:
leaf_weights = [
op.get_weight(
target_spec=target_spec,
depth=depth,
stack_size=stack_size,
template=template,
)
for _, op in leaf_ops
]
idx = random.choices(
range(len(leaf_ops)), weights=leaf_weights, k=1
)[0]
chosen_op_name, chosen_operator = leaf_ops[idx]
else:
nonleaf_weights = [
op.get_weight(
target_spec=target_spec,
depth=depth,
stack_size=stack_size,
template=template,
)
for _, op in non_leaf_ops
]
idx = random.choices(
range(len(non_leaf_ops)), weights=nonleaf_weights, k=1
)[0]
chosen_op_name, chosen_operator = non_leaf_ops[idx]
else:
# Normal probability distribution
# Normal probability distribution over all ops
all_ops = non_leaf_ops + leaf_ops
chosen_op_name, chosen_operator = (
random.choice(all_ops) if all_ops else ("arg", get_operator("arg"))
)
if all_ops:
all_weights = [
op.get_weight(
target_spec=target_spec,
depth=depth,
stack_size=stack_size,
template=template,
)
for _, op in all_ops
]
idx = random.choices(range(len(all_ops)), weights=all_weights, k=1)[0]
chosen_op_name, chosen_operator = all_ops[idx]
else:
chosen_op_name, chosen_operator = ("arg", get_operator("arg"))
if chosen_operator is None:
# If no operator found, fallback to arg
@ -354,6 +428,7 @@ def fuzz_operation_graph(
max_depth: int = 7,
seed: Optional[int] = None,
template: str = "default",
supported_ops: Optional[list[str]] = None,
) -> OperationGraph:
"""
Generate a graph of operations that produces the target specification.
@ -394,7 +469,7 @@ def fuzz_operation_graph(
nonlocal node_counter
# Generate new operation
op_name, input_specs = fuzz_op(spec, depth, stack_size, template)
op_name, input_specs = fuzz_op(spec, depth, stack_size, template, supported_ops)
# Create unique node ID
node_id = f"node_{node_counter}"