mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164688 Approved by: https://github.com/pianpwk ghstack dependencies: #164432, #164434, #164514, #164646, #164647, #164649, #164687
43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
"""Item operator implementation."""
|
|
|
|
from typing import Optional
|
|
|
|
from torchfuzz.operators.base import Operator
|
|
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec
|
|
|
|
|
|
class ItemOperator(Operator):
|
|
"""Operator for converting 0-d tensor to scalar."""
|
|
|
|
def __init__(self):
|
|
super().__init__("item")
|
|
|
|
@property
|
|
def torch_op_name(self) -> Optional[str]:
|
|
"""Item is a tensor method, not a direct torch operation."""
|
|
return None
|
|
|
|
def can_produce(self, output_spec: Spec) -> bool:
|
|
"""Item produces scalars from 0-d tensors."""
|
|
return isinstance(output_spec, ScalarSpec)
|
|
|
|
def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
|
|
"""Decompose scalar into a single-element tensor for item operation."""
|
|
if not isinstance(output_spec, ScalarSpec):
|
|
raise ValueError("ItemOperator can only produce ScalarSpec outputs")
|
|
|
|
# Create a tensor spec that can produce a scalar via .item()
|
|
# Use a 0-D tensor (scalar tensor) to ensure .item() works reliably
|
|
tensor_spec = TensorSpec(size=(), stride=(), dtype=output_spec.dtype)
|
|
|
|
return [tensor_spec]
|
|
|
|
def codegen(
|
|
self, output_name: str, input_names: list[str], output_spec: Spec
|
|
) -> str:
|
|
"""Generate code for item operation."""
|
|
if len(input_names) != 1:
|
|
raise ValueError("ItemOperator requires exactly one input")
|
|
|
|
return f"{output_name} = {input_names[0]}.item()"
|