mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164716 Approved by: https://github.com/pianpwk ghstack dependencies: #164432, #164434, #164514, #164646, #164647, #164649, #164687, #164688, #164693, #164694, #164715
68 lines
2.6 KiB
Python
68 lines
2.6 KiB
Python
"""Masked select operator implementation."""
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from torchfuzz.operators.base import Operator
|
|
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
|
|
|
|
|
|
class MaskedSelectOperator(Operator):
    """Operator for selecting elements from a tensor based on a mask.

    ``torch.masked_select`` returns a 1D tensor whose length is the number of
    ``True`` entries in the mask, i.e. a data-dependent size. To keep the
    generated program's output size fixed and known ahead of time,
    :meth:`codegen` synthesizes its own input and mask instead of selecting
    from the fuzzed inputs' values.
    """

    def __init__(self):
        super().__init__("masked_select")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.masked_select"

    def can_produce(self, output_spec: Spec) -> bool:
        """Return True if ``output_spec`` is a 1D TensorSpec.

        Masked select always produces a 1D tensor; codegen synthesizes inputs
        so that exactly ``output_spec.size[0]`` elements are selected.
        """
        return isinstance(output_spec, TensorSpec) and len(output_spec.size) == 1

    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
        """Generate input specs for masked_select operation.

        Args:
            output_spec: Spec of the desired output; must be a TensorSpec.
            num_inputs: Not used; masked_select always takes exactly two
                inputs (the source tensor and a boolean mask).

        Returns:
            ``[input_spec, mask_spec]``: a contiguous ``(2, 3)`` tensor with
            the output's dtype, and a contiguous ``(2, 3)`` boolean mask.

        Raises:
            ValueError: If ``output_spec`` is not a TensorSpec.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("MaskedSelectOperator can only produce TensorSpec outputs")

        # Source tensor: a fixed, contiguous (2, 3) shape keeps the generated
        # graph simple. Its dtype follows the output spec's dtype.
        input_tensor_spec = TensorSpec(
            size=(2, 3),
            stride=(3, 1),
            dtype=output_spec.dtype,
        )

        # Mask tensor: must be boolean; it has the same shape as the source
        # so no broadcasting is involved.
        mask_spec = TensorSpec(
            size=(2, 3),
            stride=(3, 1),
            dtype=torch.bool,
        )

        return [input_tensor_spec, mask_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for masked_select with synthesized inputs to match size.

        Constructs an input tensor and mask so that exactly k elements are
        selected, where ``k = output_spec.size[0]``. No data-dependent guards.

        Args:
            output_name: Variable name the result is assigned to.
            input_names: Names of the two inputs; only the first input's
                ``.device`` is read by the generated code.
            output_spec: Spec of the result; must be a 1D TensorSpec.

        Returns:
            Python source (newline-separated statements) that assigns the
            result to ``output_name``.

        Raises:
            ValueError: If there are not exactly two inputs, or if
                ``output_spec`` is not a 1D TensorSpec.
        """
        if len(input_names) != 2:
            raise ValueError("MaskedSelectOperator requires exactly two inputs")
        if not isinstance(output_spec, TensorSpec) or len(output_spec.size) != 1:
            raise ValueError("MaskedSelectOperator requires 1D TensorSpec output")

        k = output_spec.size[0]
        src = input_names[0]
        # Build a 1D tensor of length max(k, 1) plus a mask whose first k
        # positions are True, so masked_select returns exactly k elements.
        # The tensor is placed on the first input's device and cast to the
        # *output spec's* dtype (str(torch.dtype) renders as e.g.
        # "torch.float32"), so the result's dtype matches output_spec even
        # if the input's dtype differs. .to() is used rather than arange's
        # dtype= because arange does not support every dtype (e.g. bool).
        return (
            f"_x_ms = torch.arange(max({k}, 1), device={src}.device).to({output_spec.dtype})\n"
            "_mask_ms = torch.zeros_like(_x_ms, dtype=torch.bool)\n"
            f"_mask_ms[:{k}] = True\n"
            f"{output_name} = torch.masked_select(_x_ms, _mask_ms)"
        )
|