mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchfuzz] fix bool propagation (#164003)
bools can't propogate through the current pointwise ops such as add/mul. once we add more that can, we'll probably want to add an additional subclass that supports pointwise bools, but for now just don't allow it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164003 Approved by: https://github.com/pianpwk ghstack dependencies: #163743, #163812, #163890, #164002
This commit is contained in:
committed by
PyTorch MergeBot
parent
280e712c13
commit
dcb8af7501
@ -3,6 +3,8 @@
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from torchfuzz.operators.base import Operator
|
||||
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec
|
||||
|
||||
@ -21,6 +23,8 @@ class ScalarPointwiseOperator(Operator):
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Scalar pointwise operations can only produce scalars."""
|
||||
if output_spec.dtype == torch.bool:
|
||||
return False
|
||||
return isinstance(output_spec, ScalarSpec)
|
||||
|
||||
def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
|
||||
|
@ -3,6 +3,8 @@
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from torchfuzz.operators.base import Operator
|
||||
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
|
||||
from torchfuzz.type_promotion import (
|
||||
@ -27,6 +29,10 @@ class PointwiseOperator(Operator):
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Tensor pointwise operations can produce tensors but not scalars."""
|
||||
if not super().can_produce(output_spec):
|
||||
return False
|
||||
if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
|
||||
return False
|
||||
return isinstance(output_spec, TensorSpec)
|
||||
|
||||
def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
|
||||
|
@ -136,6 +136,9 @@ def get_promotion_table_for_strings() -> dict:
|
||||
("int32", "int64"),
|
||||
("int64", "int32"),
|
||||
],
|
||||
"bool": [
|
||||
("bool", "bool"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user