[Inductor/Triton] Upcast FP16/BF16 math reductions to FP32 (#141052)

Summary:
Triton compiler does not automatically promote fp16/bf16 reductions to fp32  accumulation. This will result in significant accuracy issue.

This diff will upcast the input to FP32 for all math reductions `["welford_reduce", "welford_combine", "prod", "sum", "xor_sum"]`

Test Plan:
CI
```
python test/inductor/test_torchinductor.py TritonCodeGenTests.test_low_precision_reduction
```

Differential Revision: D65965032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141052
Approved by: https://github.com/blaine-rister
This commit is contained in:
Xinran / Allan Rui
2025-01-04 07:57:08 +00:00
committed by PyTorch MergeBot
parent 816328fa51
commit 417d9c3522
2 changed files with 71 additions and 0 deletions

View File

@ -30,6 +30,7 @@ from sympy.printing.precedence import PRECEDENCE
import torch
import torch._logging
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import identity, preserve_rng_state
from torch._prims_common import is_integer_dtype
@ -2284,6 +2285,27 @@ class TritonKernel(SIMDKernel):
reduction_type: ReductionType,
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
def maybe_upcast(value: CSEVariable) -> CSEVariable:
# Math reductions in FP16/BF16 are less accurate because the Triton compiler does not
# automatically promote to FP32 for accumulation. Additionally, max/min reductions
# do not support FP16/BF16. We manually promote to FP32 here.
return (
ops.to_dtype(value, torch.float32)
if value.dtype
in [
torch.float16,
torch.bfloat16,
]
else value
)
original_dtypes = [val.dtype for val in pytree.tree_leaves(value)]
value = pytree.tree_map(maybe_upcast, value)
if any(x in [torch.float16, torch.bfloat16] for x in original_dtypes):
# Only promote FB16/BF16; do not promote other integer/boolean dtypes
src_dtype = torch.promote_types(src_dtype, torch.float32)
dtype = torch.promote_types(dtype, torch.float32)
assert self.inside_reduction
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
self.filter_masks(masks)
@ -2545,10 +2567,30 @@ class TritonKernel(SIMDKernel):
if isinstance(result_var, tuple):
assert all(isinstance(x, TritonCSEVariable) for x in result_var)
self.outside_loop_vars.update(result_var)
# Match output dtype with input dtype
if reduction_type == "welford_reduce":
assert len(original_dtypes) == 1
original_dtypes = len(result_var) * original_dtypes
assert len(result_var) == len(original_dtypes)
for var, orig_dtype in zip(result_var, original_dtypes):
assert orig_dtype is not None
if var.dtype != orig_dtype:
self.post_loop_combine.writeline(
f"{var} = {var}.to({triton_compute_type(orig_dtype)})"
)
else:
assert isinstance(result_var, TritonCSEVariable)
self.outside_loop_vars.add(result_var)
# Match output dtype with input dtype
if result_var.dtype != original_dtypes[0]:
assert original_dtypes[0] is not None
self.post_loop_combine.writeline(
f"{result_var} = {result_var}.to({triton_compute_type(original_dtypes[0])})"
)
return result_var
def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype):