Update torch.masked.mean to upcast dtype for bool tensors (#139999)

When calling `torch.masked.mean(...)` with a boolean tensor, the dtype is inferred to be bool. When the mean is being computed, the sum operator is used. When the sum operator is used with dtype=torch.bool, the result is clamped to True (1) leading to an incorrect mean being calculated.

The below example shows how the incorrect result occurs:
```
a = torch.tensor([True, True])
count = torch.sum(torch.ones(a.shape, dtype=torch.int64)) # 2
total = torch.sum(a, dtype=torch.bool) # True (1)
mean = total / count # 0.5
```

This PR upcasts the dtype used for the sumation to int32 in the case of bool tensors allowing for the correct result to be computed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139999
Approved by: https://github.com/cpuhrsch
This commit is contained in:
George Wigley
2025-01-08 10:35:19 +00:00
committed by PyTorch MergeBot
parent 60a505022f
commit a5051a9521
3 changed files with 9 additions and 20 deletions

View File

@ -4882,7 +4882,6 @@ class CPUReproTests(TestCase):
@requires_vectorization
def test_bool_reduction_vec(self):
for op in (
torch.masked.mean,
torch.any,
torch.min,
torch.max,

View File

@ -1384,8 +1384,16 @@ elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
dtype_source = "Optional"
if dtype is None:
dtype = input.dtype
dtype_source = "Input"
if not (dtype.is_floating_point or dtype.is_complex):
raise ValueError(
f"mean(): Could not infer output dtype. {dtype_source} dtype must be either "
f"a floating point or complex dtype. Got: {dtype}"
)
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically

View File

@ -769,26 +769,8 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
promotes_int_to_float=True,
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestReductions",
"test_ref_duplicate_values",
dtypes=(torch.bool,),
),
DecorateInfo(
unittest.expectedFailure,
"TestReductions",
"test_reference_masked",
dtypes=(torch.bool,),
),
DecorateInfo(
unittest.expectedFailure,
"TestReductions",
"test_ref_small_input",
dtypes=(torch.bool,),
),
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",