[torchfuzz] fix bool propagation (#164003)

bools can't propogate through the current pointwise ops such as add/mul. once we add more that can, we'll probably want to add an additional subclass that supports pointwise bools, but for now just don't allow it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164003
Approved by: https://github.com/pianpwk
ghstack dependencies: #163743, #163812, #163890, #164002
This commit is contained in:
bobrenjc93
2025-09-26 18:28:36 -07:00
committed by PyTorch MergeBot
parent 280e712c13
commit dcb8af7501
3 changed files with 13 additions and 0 deletions

View File

@ -3,6 +3,8 @@
import random
from typing import Optional
import torch
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec
@ -21,6 +23,8 @@ class ScalarPointwiseOperator(Operator):
def can_produce(self, output_spec: Spec) -> bool:
"""Scalar pointwise operations can only produce scalars."""
if output_spec.dtype == torch.bool:
return False
return isinstance(output_spec, ScalarSpec)
def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:

View File

@ -3,6 +3,8 @@
import random
from typing import Optional
import torch
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
from torchfuzz.type_promotion import (
@ -27,6 +29,10 @@ class PointwiseOperator(Operator):
def can_produce(self, output_spec: Spec) -> bool:
"""Tensor pointwise operations can produce tensors but not scalars."""
if not super().can_produce(output_spec):
return False
if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
return False
return isinstance(output_spec, TensorSpec)
def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:

View File

@ -136,6 +136,9 @@ def get_promotion_table_for_strings() -> dict:
("int32", "int64"),
("int64", "int32"),
],
"bool": [
("bool", "bool"),
],
}