mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update torch.masked.mean to upcast dtype for bool tensors (#139999)
When calling `torch.masked.mean(...)` with a boolean tensor, the dtype is inferred to be bool. When the mean is being computed, the sum operator is used. When the sum operator is used with dtype=torch.bool, the result is clamped to True (1), leading to an incorrect mean being calculated. The below example shows how the incorrect result occurs: ``` a = torch.tensor([True, True]) count = torch.sum(torch.ones(a.shape, dtype=torch.int64)) # 2 total = torch.sum(a, dtype=torch.bool) # True (1) mean = total / count # 0.5 (incorrect; the true mean is 1.0) ``` This PR upcasts the dtype used for the summation to int32 in the case of bool tensors, allowing the correct result to be computed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139999 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
60a505022f
commit
a5051a9521
@ -4882,7 +4882,6 @@ class CPUReproTests(TestCase):
|
||||
@requires_vectorization
|
||||
def test_bool_reduction_vec(self):
|
||||
for op in (
|
||||
torch.masked.mean,
|
||||
torch.any,
|
||||
torch.min,
|
||||
torch.max,
|
||||
|
@ -1384,8 +1384,16 @@ elements, have ``nan`` values.
|
||||
{reduction_args}
|
||||
|
||||
{reduction_example}"""
|
||||
dtype_source = "Optional"
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
dtype_source = "Input"
|
||||
|
||||
if not (dtype.is_floating_point or dtype.is_complex):
|
||||
raise ValueError(
|
||||
f"mean(): Could not infer output dtype. {dtype_source} dtype must be either "
|
||||
f"a floating point or complex dtype. Got: {dtype}"
|
||||
)
|
||||
if input.layout == torch.strided:
|
||||
if mask is None:
|
||||
# TODO: compute count analytically
|
||||
|
@ -769,26 +769,8 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
promotes_int_to_float=True,
|
||||
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
|
||||
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
"TestReductions",
|
||||
"test_ref_duplicate_values",
|
||||
dtypes=(torch.bool,),
|
||||
),
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
"TestReductions",
|
||||
"test_reference_masked",
|
||||
dtypes=(torch.bool,),
|
||||
),
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
"TestReductions",
|
||||
"test_ref_small_input",
|
||||
dtypes=(torch.bool,),
|
||||
),
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
"TestNormalizeOperators",
|
||||
|
Reference in New Issue
Block a user