Files
pytorch/tools/experimental/torchfuzz/operators/unique.py

57 lines
2.3 KiB
Python

"""Unique operator implementation."""
from typing import Optional
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
class UniqueOperator(Operator):
    """Operator that produces a 1D tensor via ``torch.unique``.

    The generated code does not deduplicate the fuzzed input tensor. It
    synthesizes a tensor with exactly the requested number of distinct
    values and runs ``torch.unique`` on that, so the output length is
    fixed by the output spec rather than by the input data.
    """

    def __init__(self):
        super().__init__("unique")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.unique"

    def can_produce(self, output_spec: Spec) -> bool:
        """Return True only for ``TensorSpec`` outputs with exactly one dimension.

        Any 1D length is accepted without guards, because ``codegen``
        synthesizes an input with exactly the desired number of unique
        elements so that ``torch.unique`` returns the target size
        deterministically.
        """
        return isinstance(output_spec, TensorSpec) and len(output_spec.size) == 1

    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
        """Generate the input spec for the unique operation.

        Args:
            output_spec: Spec of the tensor to produce; must be a ``TensorSpec``.
            num_inputs: Not used by this operator; one input spec is always
                returned.

        Returns:
            A single-element list holding a contiguous ``(2, 3)`` ``TensorSpec``
            whose dtype matches ``output_spec``.

        Raises:
            ValueError: If ``output_spec`` is not a ``TensorSpec``.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("UniqueOperator can only produce TensorSpec outputs")

        # The input's shape is irrelevant to the result: codegen only reads the
        # input's device and dtype, so a small fixed shape is used.
        input_spec = TensorSpec(
            size=(2, 3),  # Fixed size for consistency
            stride=(3, 1),  # Contiguous
            dtype=output_spec.dtype,  # Match output dtype
        )
        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for unique with a deterministic target size (no guards).

        Args:
            output_name: Variable name the generated code assigns the result to.
            input_names: Names of the input variables; exactly one is required.
                Only its ``.device`` and ``.dtype`` are referenced.
            output_spec: Spec of the tensor to produce; must be a 1D
                ``TensorSpec``.

        Returns:
            Three newline-separated statements: build an int64 ``torch.arange``
            of the desired length, apply ``torch.unique`` to it, and cast the
            result to the input's dtype.

        Raises:
            ValueError: If ``input_names`` does not hold exactly one name, or
                if ``output_spec`` is not a 1D ``TensorSpec``.
        """
        if len(input_names) != 1:
            raise ValueError("UniqueOperator requires exactly one input")

        # Reject specs that can_produce() would reject instead of silently
        # emitting code whose output shape does not match the requested spec.
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("UniqueOperator can only produce TensorSpec outputs")
        if len(output_spec.size) != 1:
            raise ValueError("UniqueOperator can only produce 1D tensor outputs")

        desired_len = output_spec.size[0]

        # Synthesize in a wide dtype (int64) to guarantee desired_len distinct
        # values, apply unique, then cast to the target dtype. The cast happens
        # after unique, so a narrow target dtype cannot shrink the result.
        return (
            f"_inp_unique_wide = torch.arange({desired_len}, device={input_names[0]}.device, dtype=torch.int64)\n"
            f"_uniq_wide = torch.unique(_inp_unique_wide)\n"
            f"{output_name} = _uniq_wide.to({input_names[0]}.dtype)"
        )