From dcb8af750133aee6e8fac5de0b94072e35bba5bb Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Fri, 26 Sep 2025 18:28:36 -0700 Subject: [PATCH] [torchfuzz] fix bool propagation (#164003) bools can't propogate through the current pointwise ops such as add/mul. once we add more that can, we'll probably want to add an additional subclass that supports pointwise bools, but for now just don't allow it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164003 Approved by: https://github.com/pianpwk ghstack dependencies: #163743, #163812, #163890, #164002 --- .../dynamic_shapes/torchfuzz/operators/scalar_pointwise.py | 4 ++++ .../dynamic_shapes/torchfuzz/operators/tensor_pointwise.py | 6 ++++++ .../experimental/dynamic_shapes/torchfuzz/type_promotion.py | 3 +++ 3 files changed, 13 insertions(+) diff --git a/tools/experimental/dynamic_shapes/torchfuzz/operators/scalar_pointwise.py b/tools/experimental/dynamic_shapes/torchfuzz/operators/scalar_pointwise.py index f5d4b6b36b22..46e72c05bce5 100644 --- a/tools/experimental/dynamic_shapes/torchfuzz/operators/scalar_pointwise.py +++ b/tools/experimental/dynamic_shapes/torchfuzz/operators/scalar_pointwise.py @@ -3,6 +3,8 @@ import random from typing import Optional +import torch + from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import ScalarSpec, Spec @@ -21,6 +23,8 @@ class ScalarPointwiseOperator(Operator): def can_produce(self, output_spec: Spec) -> bool: """Scalar pointwise operations can only produce scalars.""" + if output_spec.dtype == torch.bool: + return False return isinstance(output_spec, ScalarSpec) def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]: diff --git a/tools/experimental/dynamic_shapes/torchfuzz/operators/tensor_pointwise.py b/tools/experimental/dynamic_shapes/torchfuzz/operators/tensor_pointwise.py index 40453857b727..4f808b623077 100644 --- a/tools/experimental/dynamic_shapes/torchfuzz/operators/tensor_pointwise.py +++ b/tools/experimental/dynamic_shapes/torchfuzz/operators/tensor_pointwise.py @@ -3,6 +3,8 @@ import random from typing import Optional +import torch + from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import Spec, TensorSpec from torchfuzz.type_promotion import ( @@ -27,6 +29,10 @@ class PointwiseOperator(Operator): def can_produce(self, output_spec: Spec) -> bool: """Tensor pointwise operations can produce tensors but not scalars.""" + if not super().can_produce(output_spec): + return False + if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool: + return False return isinstance(output_spec, TensorSpec) def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]: diff --git a/tools/experimental/dynamic_shapes/torchfuzz/type_promotion.py b/tools/experimental/dynamic_shapes/torchfuzz/type_promotion.py index 102338958dff..db48b87c0b5e 100644 --- a/tools/experimental/dynamic_shapes/torchfuzz/type_promotion.py +++ b/tools/experimental/dynamic_shapes/torchfuzz/type_promotion.py @@ -136,6 +136,9 @@ def get_promotion_table_for_strings() -> dict: ("int32", "int64"), ("int64", "int32"), ], + "bool": [ + ("bool", "bool"), + ], }