Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[torchfuzz] add support for operator weights (#164649)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164649
Approved by: https://github.com/pianpwk
ghstack dependencies: #164432, #164434, #164514, #164646, #164647
Committed by: PyTorch MergeBot
Parent: ded099ecbf
Commit: 5fe7f29b9e
@@ -123,6 +123,28 @@ python fuzzer.py --seed 42 --log-level DEBUG --max-depth 5
 | `--verbose` | Print detailed output for all runs | `--verbose` |
 | `--template NAME` | Template to use for all runs | `--template default` |
 
+## Restricting supported ops and weighting examples
+
+You can restrict the fuzzer to a specific set of fully-qualified torch ops and optionally weight them to bias sampling.
+
+- Restrict to only torch.add and torch.matmul (equal likelihood):
+
+```bash
+python fuzzer.py --seed 42 \
+    --supported-ops "torch.add,torch.matmul"
+```
+
+- Restrict to only torch.add and torch.matmul, and make matmul 5x more likely than add:
+
+```bash
+python fuzzer.py --seed 42 \
+    --supported-ops "torch.add,torch.matmul=5"
+```
+
+Notes:
+- Use fully-qualified torch op names (e.g., torch.matmul, torch.nn.functional.rms_norm).
+- Weights must be > 0. If both --supported-ops and --op-weights specify a weight for the same op, the value from --supported-ops takes precedence.
+
 ## Architecture
 
 ### Core Components
@@ -20,11 +20,42 @@ from torchfuzz.runner import ProgramRunner
 from torchfuzz.visualize_graph import visualize_operation_graph
 
 
+def _parse_supported_ops_with_weights(spec: str) -> tuple[list[str], dict[str, float]]:
+    """Parse --supported-ops string.
+
+    Format: comma-separated fully-qualified torch ops, each optionally with =weight.
+    Example: "torch.matmul=5,torch.nn.functional.rms_norm=5,torch.add"
+    Returns (ops_list, weights_dict)
+    """
+    ops: list[str] = []
+    weights: dict[str, float] = {}
+    if not spec:
+        return ops, weights
+    for entry in spec.split(","):
+        entry = entry.strip()
+        if not entry:
+            continue
+        if "=" in entry:
+            name, w = entry.split("=", 1)
+            name = name.strip()
+            try:
+                weight = float(w.strip())
+            except ValueError:
+                continue
+            ops.append(name)
+            weights[name] = weight
+        else:
+            ops.append(entry)
+    return ops, weights
+
+
 def fuzz_and_execute(
     seed: Optional[int] = None,
     max_depth: Optional[int] = None,
     log_at_faluire: bool = False,
     template: str = "default",
+    supported_ops: Optional[list[str]] = None,
+    op_weights: Optional[dict[str, float]] = None,
 ) -> None:
     """
     Generate a fuzzed operation stack, convert it to Python code, and execute it.
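For illustration, the parser above yields the ops in order and records weights only for entries that carry `=`; an entry whose weight fails to parse is dropped entirely (the `continue` fires before the op is appended). A minimal sketch, assuming the function as defined in this hunk:

```python
# Expected behavior of _parse_supported_ops_with_weights (names from the diff above).
ops, weights = _parse_supported_ops_with_weights(
    "torch.matmul=5,torch.nn.functional.rms_norm=5,torch.add"
)
assert ops == ["torch.matmul", "torch.nn.functional.rms_norm", "torch.add"]
assert weights == {"torch.matmul": 5.0, "torch.nn.functional.rms_norm": 5.0}

# A malformed weight drops the whole entry, not just the weight.
ops, weights = _parse_supported_ops_with_weights("torch.add=abc,torch.mul")
assert ops == ["torch.mul"] and weights == {}
```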
@@ -113,6 +144,12 @@ def fuzz_and_execute(
     logger.debug("⏱️ Step 1: Generating target spec...")
     start_time = time.time()
     target_spec = fuzz_spec(template)
+
+    # Apply user-specified operator weights (if provided)
+    if op_weights:
+        from torchfuzz.operators import set_operator_weights
+
+        set_operator_weights(op_weights)
     logger.debug(
         " Completed in %.3fs - %s", time.time() - start_time, target_spec
     )
@@ -120,7 +157,11 @@
     logger.debug("⏱️ Step 2: Generating operation graph...")
     start_time = time.time()
     operation_graph = fuzz_operation_graph(
-        target_spec, max_depth=max_depth, seed=seed, template=template
+        target_spec,
+        max_depth=max_depth,
+        seed=seed,
+        template=template,
+        supported_ops=supported_ops,
     )
 
     # Extract and print operation statistics
@@ -226,6 +267,15 @@ if __name__ == "__main__":
         default="default",
         help="Template to use for code generation (default: default)",
     )
+    parser.add_argument(
+        "--supported-ops",
+        type=str,
+        help=(
+            "Comma-separated fully-qualified torch ops to allow, each optionally with =weight. "
+            "Examples: 'torch.matmul,torch.nn.functional.rms_norm' or "
+            "'torch.matmul=5,torch.nn.functional.rms_norm=5'. Overrides template supported ops."
+        ),
+    )
 
     # Multi-process fuzzing arguments
     parser.add_argument(
@@ -272,8 +322,20 @@
     if args.seed is not None or args.single:
         # Single seed execution mode
         print("Running single fuzz_and_execute...")
+        # Parse supported ops and any inline weights from that flag
+        parsed_supported_ops: Optional[list[str]] = None
+        parsed_weights: dict[str, float] = {}
+        if args.supported_ops:
+            parsed_supported_ops, parsed_weights = _parse_supported_ops_with_weights(
+                args.supported_ops
+            )
+
         fuzz_and_execute(
-            seed=args.seed, max_depth=args.max_depth, template=args.template
+            seed=args.seed,
+            max_depth=args.max_depth,
+            template=args.template,
+            supported_ops=parsed_supported_ops,
+            op_weights=(parsed_weights if parsed_weights else None),
         )
     elif args.start is not None or args.count is not None:
         # Multi-process fuzzing mode
@@ -305,6 +367,7 @@ if __name__ == "__main__":
             seed_count=args.count,
             verbose=args.verbose,
             template=args.template,
+            supported_ops=args.supported_ops,
         )
     except Exception as e:
         print(f"❌ Unexpected error: {str(e)}")
@@ -90,13 +90,18 @@ def is_ignored_output(output: str) -> int:
     return -1
 
 
-def run_fuzzer_with_seed(seed: int, template: str = "default") -> FuzzerResult:
+def run_fuzzer_with_seed(
+    seed: int,
+    template: str = "default",
+    supported_ops: Optional[str] = None,
+) -> FuzzerResult:
     """
     Run fuzzer.py with a specific seed.
 
     Args:
         seed: The seed value to pass to fuzzer.py
         template: The template to use for code generation
+        supported_ops: Comma-separated ops string with optional weights
 
     Returns:
         FuzzerResult dataclass instance
@@ -115,6 +120,10 @@ def run_fuzzer_with_seed(seed: int, template: str = "default") -> FuzzerResult:
         template,
     ]
 
+    # Append supported ops if provided
+    if supported_ops:
+        cmd.extend(["--supported-ops", supported_ops])
+
     result = subprocess.run(
         cmd,
         capture_output=True,
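A minimal sketch of the resulting command list (the prefix of `cmd` is an assumption; the diff shows only its tail):

```python
# Hypothetical seed/prefix; only the extend() behavior is taken from the diff.
cmd = ["python", "fuzzer.py", "--seed", "7", "--template", "default"]
supported_ops = "torch.add,torch.matmul=5"
if supported_ops:
    cmd.extend(["--supported-ops", supported_ops])
print(cmd)
# ['python', 'fuzzer.py', '--seed', '7', '--template', 'default',
#  '--supported-ops', 'torch.add,torch.matmul=5']
```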
@@ -213,6 +222,7 @@ def run_multi_process_fuzzer(
     seed_count: int = 100,
     verbose: bool = False,
     template: str = "default",
+    supported_ops: Optional[str] = None,
 ) -> None:
     """
     Run the multi-process fuzzer.
@@ -222,6 +232,8 @@
         seed_start: Starting seed value (inclusive)
         seed_count: Number of seeds to run
         verbose: Whether to print detailed output
         template: The template to use for code generation
+        supported_ops: Comma-separated ops string with optional weights
     """
     seeds = list(range(seed_start, seed_start + seed_count))
 
@@ -250,7 +262,9 @@
     # Submit all seeds to the process pool
     future_results = []
     for seed in seeds:
-        future = pool.apply_async(run_fuzzer_with_seed, (seed, template))
+        future = pool.apply_async(
+            run_fuzzer_with_seed, (seed, template, supported_ops)
+        )
         future_results.append(future)
 
     # Set up progress bar
@@ -25,7 +25,15 @@ from torchfuzz.operators.nn_functional import (
     ReLUOperator,
     SoftmaxOperator,
 )
-from torchfuzz.operators.registry import get_operator, list_operators, register_operator
+from torchfuzz.operators.registry import (
+    get_operator,
+    list_operators,
+    register_operator,
+    set_operator_weight,
+    set_operator_weight_by_torch_op,
+    set_operator_weights,
+    set_operator_weights_by_torch_op,
+)
 from torchfuzz.operators.scalar_pointwise import (
     ScalarAddOperator,
     ScalarDivOperator,
@@ -75,4 +83,8 @@ __all__ = [
     "get_operator",
     "register_operator",
     "list_operators",
+    "set_operator_weight",
+    "set_operator_weights",
+    "set_operator_weight_by_torch_op",
+    "set_operator_weights_by_torch_op",
 ]
@@ -9,9 +9,16 @@ from torchfuzz.tensor_fuzzer import Spec
 class Operator(ABC):
     """Base class for all operators in torchfuzz."""
 
-    def __init__(self, name: str):
-        """Initialize operator with name."""
+    def __init__(self, name: str, weight: float = 1.0):
+        """Initialize operator with name and optional selection weight.
+
+        Args:
+            name: Unique operator name used in the registry
+            weight: Relative selection weight when sampling among compatible operators
+                (default 1.0). Higher values increase selection likelihood.
+        """
         self.name = name
+        self.weight: float = float(weight)
 
     @property
     @abstractmethod
@@ -32,10 +39,13 @@ class Operator(ABC):
 
     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
         """
-        Get input specifications for fuzzing. By default, delegates to decompose.
-        Leaf operators should override this to return an empty list.
+        Get input specifications for fuzzing.
+
+        Subclasses must implement this to return a list of input Specs that,
+        when used with this operator, can produce the given output_spec. Leaf
+        operators should return an empty list.
         """
-        return self.decompose(output_spec)
+        raise NotImplementedError("Subclasses must implement fuzz_inputs_specs()")
 
     @abstractmethod
     def codegen(
@@ -44,6 +54,22 @@
         """Generate code for this operation."""
         raise NotImplementedError("Subclasses must implement codegen()")
 
+    def get_weight(
+        self,
+        *,
+        target_spec: Optional[Spec] = None,
+        depth: Optional[int] = None,
+        stack_size: Optional[int] = None,
+        template: Optional[str] = None,
+    ) -> float:
+        """
+        Return the selection weight for this operator.
+
+        Subclasses may override to implement context-sensitive weighting.
+        The default implementation returns the static attribute `self.weight`.
+        """
+        return self.weight
+
     def __str__(self) -> str:
         """String representation of the operator."""
         return f"{self.__class__.__name__}({self.name})"
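A minimal sketch of what this hook enables (hypothetical class, not part of the diff): a subclass can make its weight context-sensitive, here damping itself at high recursion depth:

```python
from typing import Optional


class DepthDampedWeight:
    """Hypothetical example: halve the selection weight past depth 5."""

    weight: float = 4.0

    def get_weight(
        self,
        *,
        target_spec=None,
        depth: Optional[int] = None,
        stack_size: Optional[int] = None,
        template: Optional[str] = None,
    ) -> float:
        if depth is not None and depth > 5:
            return self.weight * 0.5  # damp deep recursion
        return self.weight


print(DepthDampedWeight().get_weight(depth=7))  # 2.0
```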
@@ -61,6 +61,7 @@ class MMOperator(MatrixMultiplyOperator):
 
     def __init__(self):
         super().__init__("mm", "torch.mm")
+        self.weight = 5.0
 
     def can_produce(self, output_spec: Spec) -> bool:
         """MM requires exactly 2D tensors."""
@@ -141,6 +142,7 @@ class AddmmOperator(MatrixMultiplyOperator):
 
     def __init__(self):
         super().__init__("addmm", "torch.addmm")
+        self.weight = 5.0
 
     def can_produce(self, output_spec: Spec) -> bool:
         """Addmm requires exactly 2D tensors."""
@@ -229,6 +231,7 @@ class BmmOperator(MatrixMultiplyOperator):
 
     def __init__(self):
         super().__init__("bmm", "torch.bmm")
+        self.weight = 5.0
 
     def can_produce(self, output_spec: Spec) -> bool:
         """Batch matrix multiply requires 3D tensors."""
@@ -309,6 +312,7 @@ class MatmulOperator(MatrixMultiplyOperator):
 
     def __init__(self):
         super().__init__("matmul", "torch.matmul")
+        self.weight = 500.0
 
     def can_produce(self, output_spec: Spec) -> bool:
         """Matmul can handle various tensor dimensions >= 1."""
@@ -434,6 +434,7 @@ class RMSNormOperator(Operator):
 
     def __init__(self):
         super().__init__("torch.nn.functional.rms_norm")
+        self.weight = 5.0
 
     @property
     def torch_op_name(self) -> Optional[str]:
@@ -154,3 +154,52 @@ def register_operator(operator: Operator):
 def list_operators() -> dict[str, Operator]:
     """List all operators in the global registry."""
     return _global_registry.list_operators()
+
+
+def set_operator_weight(op_name: str, weight: float) -> None:
+    """Set the selection weight for a specific operator.
+
+    Args:
+        op_name: The registered operator name (e.g., "add", "arg") OR fully-qualified torch op
+            (e.g., "torch.nn.functional.relu", "torch.matmul")
+        weight: New relative selection weight (must be > 0)
+    """
+    if weight <= 0:
+        raise ValueError("Operator weight must be > 0")
+
+    # Try by registry key
+    op = _global_registry.get(op_name)
+    if op is not None:
+        op.weight = float(weight)
+        return
+
+    # Fallback: try to locate by fully-qualified torch op name
+    for candidate in _global_registry.list_operators().values():
+        if getattr(candidate, "torch_op_name", None) == op_name:
+            candidate.weight = float(weight)
+            return
+
+    raise KeyError(f"Operator '{op_name}' not found by registry name or torch op name")
+
+
+def set_operator_weights(weights: dict[str, float]) -> None:
+    """Bulk-update operator weights from a mapping of name -> weight."""
+    for name, w in weights.items():
+        set_operator_weight(name, w)
+
+
+def set_operator_weight_by_torch_op(torch_op_name: str, weight: float) -> None:
+    """Set operator weight by fully-qualified torch op name."""
+    if weight <= 0:
+        raise ValueError("Operator weight must be > 0")
+    for candidate in _global_registry.list_operators().values():
+        if getattr(candidate, "torch_op_name", None) == torch_op_name:
+            candidate.weight = float(weight)
+            return
+    raise KeyError(f"Torch op '{torch_op_name}' not found in registry")
+
+
+def set_operator_weights_by_torch_op(weights: dict[str, float]) -> None:
+    """Bulk-update weights by fully-qualified torch op names."""
+    for name, w in weights.items():
+        set_operator_weight_by_torch_op(name, w)
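A usage sketch, assuming the default torchfuzz registry is populated (the "add" registry name and the torch-op fallback are taken from the docstrings above; the imports match the `torchfuzz.operators` re-exports added in this commit):

```python
from torchfuzz.operators import (
    set_operator_weight,
    set_operator_weight_by_torch_op,
    set_operator_weights,
)

set_operator_weight("add", 2.0)  # by registry name
set_operator_weight_by_torch_op("torch.matmul", 5.0)  # by fully-qualified torch op
set_operator_weights({"torch.nn.functional.rms_norm": 5.0})  # bulk; either naming works
```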
@@ -86,8 +86,9 @@ class PointwiseOperator(Operator):
 class AddOperator(PointwiseOperator):
     """Operator for element-wise addition."""
 
-    def __init__(self):
+    def __init__(self, weight: float = 1.0):
         super().__init__("add", "torch.add", "+")
+        self.weight = float(weight)
 
 
 class MulOperator(PointwiseOperator):
@@ -30,8 +30,15 @@ def _get_cached_operators():
     return _CACHED_OPERATORS
 
 
-def _get_template_filtered_operators(template: str = "default"):
-    """Get operators filtered by template's supported_ops."""
+def _get_template_filtered_operators(
+    template: str = "default", supported_ops: Optional[list[str]] = None
+):
+    """Get operators filtered by template's supported_ops, with user override.
+
+    If supported_ops is provided, it takes precedence and is used to filter the
+    registry. Otherwise, the template's supported_ops are used. If neither are
+    specified, all operators are returned.
+    """
     # Instantiate template
     if template == "dtensor":
         from torchfuzz.codegen import DTensorFuzzTemplate
@@ -48,11 +55,14 @@ def _get_template_filtered_operators(template: str = "default"):
 
     all_operators = _get_cached_operators()
 
+    # Determine allowed ops list
+    allowed_ops = supported_ops if supported_ops else fuzz_template.supported_ops
+
     # If no supported_ops specified, return all operators
-    if not fuzz_template.supported_ops:
+    if not allowed_ops:
         return all_operators
 
-    # Filter operators based on supported_ops
+    # Filter operators based on allowed_ops
     filtered_ops = {}
 
     for op_name, operator in all_operators.items():
@@ -66,9 +76,9 @@ def _get_template_filtered_operators(template: str = "default"):
             filtered_ops[op_name] = operator
             continue
 
-        # Check if the operator supports any of the template's operations
+        # Check if the operator supports any of the allowed operations
         should_include = False
-        for supported_op in fuzz_template.supported_ops:
+        for supported_op in allowed_ops:
             # Direct torch operation matching
             if torch_op == supported_op:
                 should_include = True
@@ -260,7 +270,11 @@ def fuzz_spec(template: str = "default") -> Spec:
 
 
 def fuzz_op(
-    target_spec: Spec, depth, stack_size, template: str = "default"
+    target_spec: Spec,
+    depth,
+    stack_size,
+    template: str = "default",
+    supported_ops: Optional[list[str]] = None,
 ) -> tuple[str, list[Spec]]:
     """
     Given an output specification, returns an operation that can
@@ -277,7 +291,7 @@ def fuzz_op(
         describes the layout requirements for the operation's inputs
     """
     # Get template-filtered operators
-    available_operators = _get_template_filtered_operators(template)
+    available_operators = _get_template_filtered_operators(template, supported_ops)
 
     # Filter operators that can produce the target spec
     compatible_ops = []
@@ -306,24 +320,84 @@
         if not leaf_ops:
             # If no leaf ops can produce this spec, fallback to arg
             return _get_arg_args_specs(target_spec)
-        chosen_op_name, chosen_operator = random.choice(leaf_ops)
+        # Weighted choice among leaf ops
+        leaf_weights = [
+            op.get_weight(
+                target_spec=target_spec,
+                depth=depth,
+                stack_size=stack_size,
+                template=template,
+            )
+            for _, op in leaf_ops
+        ]
+        idx = random.choices(range(len(leaf_ops)), weights=leaf_weights, k=1)[0]
+        chosen_op_name, chosen_operator = leaf_ops[idx]
     else:
         # At higher depths, choose between leaf and non-leaf operations
         # Reduce probability of leaf operations when stack_size < 10
         if (stack_size < 10 or depth > 7) and non_leaf_ops:
             # 80% chance of non-leaf, 20% chance of leaf
             if random.random() < 0.8:
-                chosen_op_name, chosen_operator = random.choice(non_leaf_ops)
+                # Weighted choice among non-leaf ops
+                nonleaf_weights = [
+                    op.get_weight(
+                        target_spec=target_spec,
+                        depth=depth,
+                        stack_size=stack_size,
+                        template=template,
+                    )
+                    for _, op in non_leaf_ops
+                ]
+                idx = random.choices(
+                    range(len(non_leaf_ops)), weights=nonleaf_weights, k=1
+                )[0]
+                chosen_op_name, chosen_operator = non_leaf_ops[idx]
             else:
-                chosen_op_name, chosen_operator = (
-                    random.choice(leaf_ops) if leaf_ops else random.choice(non_leaf_ops)
-                )
+                if leaf_ops:
+                    leaf_weights = [
+                        op.get_weight(
+                            target_spec=target_spec,
+                            depth=depth,
+                            stack_size=stack_size,
+                            template=template,
+                        )
+                        for _, op in leaf_ops
+                    ]
+                    idx = random.choices(
+                        range(len(leaf_ops)), weights=leaf_weights, k=1
+                    )[0]
+                    chosen_op_name, chosen_operator = leaf_ops[idx]
+                else:
+                    nonleaf_weights = [
+                        op.get_weight(
+                            target_spec=target_spec,
+                            depth=depth,
+                            stack_size=stack_size,
+                            template=template,
+                        )
+                        for _, op in non_leaf_ops
+                    ]
+                    idx = random.choices(
+                        range(len(non_leaf_ops)), weights=nonleaf_weights, k=1
+                    )[0]
+                    chosen_op_name, chosen_operator = non_leaf_ops[idx]
         else:
-            # Normal probability distribution
+            # Normal probability distribution over all ops
            all_ops = non_leaf_ops + leaf_ops
-            chosen_op_name, chosen_operator = (
-                random.choice(all_ops) if all_ops else ("arg", get_operator("arg"))
-            )
+            if all_ops:
+                all_weights = [
+                    op.get_weight(
+                        target_spec=target_spec,
+                        depth=depth,
+                        stack_size=stack_size,
+                        template=template,
+                    )
+                    for _, op in all_ops
+                ]
+                idx = random.choices(range(len(all_ops)), weights=all_weights, k=1)[0]
+                chosen_op_name, chosen_operator = all_ops[idx]
+            else:
+                chosen_op_name, chosen_operator = ("arg", get_operator("arg"))
 
     if chosen_operator is None:
         # If no operator found, fallback to arg
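All three branches rely on `random.choices`, which draws index `i` with probability `weights[i] / sum(weights)`. A standalone sketch of that behavior:

```python
import random
from collections import Counter

ops = ["torch.add", "torch.matmul"]
weights = [1.0, 5.0]  # matmul weighted 5x, as in the README example

draws = Counter(
    ops[random.choices(range(len(ops)), weights=weights, k=1)[0]]
    for _ in range(10_000)
)
print(draws)  # torch.matmul should appear roughly 5x as often as torch.add
```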
@@ -354,6 +428,7 @@ def fuzz_operation_graph(
     max_depth: int = 7,
     seed: Optional[int] = None,
     template: str = "default",
+    supported_ops: Optional[list[str]] = None,
 ) -> OperationGraph:
     """
     Generate a graph of operations that produces the target specification.
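A call sketch showing the new parameter end to end (the import path is an assumption; `fuzz_spec` appears earlier in this file):

```python
# Hypothetical module path; adjust to wherever fuzz_spec/fuzz_operation_graph live.
from torchfuzz.ops_fuzzer import fuzz_operation_graph, fuzz_spec

target_spec = fuzz_spec("default")
graph = fuzz_operation_graph(
    target_spec,
    max_depth=5,
    seed=0,
    template="default",
    supported_ops=["torch.add", "torch.matmul"],
)
```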
@@ -394,7 +469,7 @@
         nonlocal node_counter
 
         # Generate new operation
-        op_name, input_specs = fuzz_op(spec, depth, stack_size, template)
+        op_name, input_specs = fuzz_op(spec, depth, stack_size, template, supported_ops)
 
         # Create unique node ID
         node_id = f"node_{node_counter}"