[torchfuzz] synthesize inputs for data dependent ops (#164716)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164716
Approved by: https://github.com/pianpwk
ghstack dependencies: #164432, #164434, #164514, #164646, #164647, #164649, #164687, #164688, #164693, #164694, #164715
This commit is contained in:
bobrenjc93
2025-10-06 10:03:52 -07:00
committed by PyTorch MergeBot
parent 2a6cdba6e5
commit 4bcc05777e
4 changed files with 84 additions and 51 deletions

View File

@ -685,10 +685,19 @@ def generate_simple_operation_code(
if operator is not None:
# Use the class-based operator to generate code
code_line = operator.codegen(output_var, input_vars, output_spec)
# Add tensor descriptor comment
code = operator.codegen(output_var, input_vars, output_spec)
# Add tensor descriptor comment to the last emitted line
descriptor_comment = f"# {format_tensor_descriptor(output_spec)}"
return [code_line + " " + descriptor_comment]
if "\n" in code:
lines = code.split("\n")
# Attach comment to the last non-empty line
for i in range(len(lines) - 1, -1, -1):
if lines[i].strip():
lines[i] = lines[i] + " " + descriptor_comment
break
return lines
else:
return [code + " " + descriptor_comment]
else:
# Fallback for unknown operations
return [f"# Unknown operation: {op_name}"]

View File

@ -20,17 +20,8 @@ class MaskedSelectOperator(Operator):
return "torch.masked_select"
def can_produce(self, output_spec: Spec) -> bool:
"""Masked select produces a 1D tensor with data-dependent size."""
if not isinstance(output_spec, TensorSpec):
return False
# Output is always 1D with data-dependent size
# Be very restrictive to avoid shape mismatches
return (
len(output_spec.size) == 1
and output_spec.size[0] <= 10 # Reasonable size
and output_spec.dtype not in [torch.bool]
) # Avoid bool outputs
"""Masked select produces a 1D tensor; we'll synthesize inputs to match size."""
return isinstance(output_spec, TensorSpec) and len(output_spec.size) == 1
def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
"""Generate input specs for masked_select operation."""
@ -56,10 +47,21 @@ class MaskedSelectOperator(Operator):
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for masked_select operation."""
"""Generate code for masked_select with synthesized inputs to match size.
Constructs an input tensor and mask so that exactly k elements are selected,
where k = output_spec.size[0]. No data-dependent guards.
"""
if len(input_names) != 2:
raise ValueError("MaskedSelectOperator requires exactly two inputs")
if not isinstance(output_spec, TensorSpec) or len(output_spec.size) != 1:
raise ValueError("MaskedSelectOperator requires 1D TensorSpec output")
k = output_spec.size[0]
# Build a 1D input of length >= k and a mask with first k positions True
# Use the input's device and dtype to avoid mismatches
return (
f"{output_name} = torch.masked_select({input_names[0]}, {input_names[1]})"
f"_x_ms = torch.arange(max({k}, 1), device={input_names[0]}.device).to({input_names[0]}.dtype)\n"
f"_mask_ms = torch.zeros_like(_x_ms, dtype=torch.bool)\n"
f"_mask_ms[:{k}] = True\n"
f"{output_name} = torch.masked_select(_x_ms, _mask_ms)"
)

View File

@ -20,39 +20,61 @@ class NonzeroOperator(Operator):
return "torch.nonzero"
def can_produce(self, output_spec: Spec) -> bool:
"""Nonzero produces a tensor with shape (n_nonzero, n_dims)."""
if not isinstance(output_spec, TensorSpec):
return False
"""Nonzero produces a tensor with shape (n_nonzero, n_dims).
# Output shape is (n_nonzero, n_dims) where both are data-dependent
# We can only produce integer tensors (indices) and only 2D tensors
# Restrict to very specific shapes to avoid shape mismatches
We can deterministically synthesize inputs to match any 2D int64 output
shape (k, d) without data-dependent guards by constructing an input with
exactly k non-zero elements and d dimensions.
"""
return (
output_spec.dtype in [torch.int64, torch.long]
isinstance(output_spec, TensorSpec)
and output_spec.dtype in [torch.int64, torch.long]
and len(output_spec.size) == 2
and output_spec.size[1] <= 4
) # Reasonable input dimensionality
)
def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
"""Generate input spec for nonzero operation."""
"""Generate input spec for nonzero operation.
The actual values will be synthesized in codegen to achieve the target size.
"""
if not isinstance(output_spec, TensorSpec):
raise ValueError("NonzeroOperator can only produce TensorSpec outputs")
# Input can be any tensor type that supports comparison with zero
# Use boolean tensors for simplicity to ensure some nonzero elements
# Provide a placeholder spec; codegen will ignore the actual input content
# and synthesize a tensor with desired nonzero count and dimensionality.
d = output_spec.size[1]
input_spec = TensorSpec(
size=(3, 4), # Fixed size that will have some nonzero elements
stride=(4, 1), # Contiguous
dtype=torch.bool, # Boolean tensors are good for nonzero testing
size=tuple([1] * d) if d > 0 else (),
stride=tuple([1] * d) if d > 0 else (),
dtype=torch.bool,
)
return [input_spec]
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for nonzero operation."""
"""Generate code for nonzero using synthesized input to match target size.
No data-dependent conditionals/guards. Constructs an input with exactly
k = output_spec.size[0] non-zero elements and d = output_spec.size[1] dims,
then calls torch.nonzero on it.
"""
if len(input_names) != 1:
raise ValueError("NonzeroOperator requires exactly one input")
return f"{output_name} = torch.nonzero({input_names[0]})"
if not isinstance(output_spec, TensorSpec) or len(output_spec.size) != 2:
raise ValueError("NonzeroOperator requires 2D TensorSpec output")
k = output_spec.size[0]
d = output_spec.size[1]
# Construct concrete shape literal like (k, 1, 1, ...)
shape_elems = [str(k)] + ["1"] * max(0, d - 1)
shape_literal = (
"(" + ", ".join(shape_elems) + ("," if d == 1 else "") + ")"
if d > 0
else "()"
)
return (
f"_x_nz = torch.zeros({shape_literal}, dtype=torch.bool, device={input_names[0]}.device)\n"
f"_x_nz_flat = _x_nz.reshape(-1)\n"
f"_x_nz_flat[:{k}] = True\n"
f"{output_name} = torch.nonzero(_x_nz)"
)

View File

@ -2,8 +2,6 @@
from typing import Optional
import torch
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
@ -20,17 +18,12 @@ class UniqueOperator(Operator):
return "torch.unique"
def can_produce(self, output_spec: Spec) -> bool:
"""Unique produces a 1D tensor with data-dependent size."""
if not isinstance(output_spec, TensorSpec):
return False
"""Unique can produce 1D tensor outputs of arbitrary length without guards.
# Output is always 1D with data-dependent size
# Be very restrictive to avoid shape mismatches
return (
len(output_spec.size) == 1
and output_spec.size[0] <= 10 # Reasonable size
and output_spec.dtype not in [torch.bool]
) # Avoid bool outputs
We will synthesize an input with exactly the desired number of unique
elements so that torch.unique returns the target size deterministically.
"""
return isinstance(output_spec, TensorSpec) and len(output_spec.size) == 1
def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
"""Generate input spec for unique operation."""
@ -49,8 +42,15 @@ class UniqueOperator(Operator):
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for unique operation."""
"""Generate code for unique with deterministic target size input (no guards)."""
if len(input_names) != 1:
raise ValueError("UniqueOperator requires exactly one input")
return f"{output_name} = torch.unique({input_names[0]})"
# Desired output length (0 when output_spec is not a TensorSpec)
desired_len = output_spec.size[0] if isinstance(output_spec, TensorSpec) else 0
# Synthesize in a wide dtype (int64) to guarantee desired_len distinct values,
# apply unique, then cast to the target dtype. No conditionals or guards.
return (
f"_inp_unique_wide = torch.arange({desired_len}, device={input_names[0]}.device, dtype=torch.int64)\n"
f"_uniq_wide = torch.unique(_inp_unique_wide)\n"
f"{output_name} = _uniq_wide.to({input_names[0]}.dtype)"
)