[Inductor/Triton] Upcast FP16/BF16 math reductions to FP32 (#141052)

Summary:
Triton compiler does not automatically promote fp16/bf16 reductions to fp32  accumulation. This will result in significant accuracy issue.

This diff will upcast the input to FP32 for all math reductions `["welford_reduce", "welford_combine", "prod", "sum", "xor_sum"]`

Test Plan:
CI
```
python test/inductor/test_torchinductor.py TritonCodeGenTests.test_low_precision_reduction
```

Differential Revision: D65965032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141052
Approved by: https://github.com/blaine-rister
This commit is contained in:
Xinran / Allan Rui
2025-01-04 07:57:08 +00:00
committed by PyTorch MergeBot
parent 816328fa51
commit 417d9c3522
2 changed files with 71 additions and 0 deletions

View File

@ -95,6 +95,35 @@ class TestCase(InductorTestCase):
fp32_cast_in_code = "to(tl.float32)" in code
self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
@requires_gpu()
@parametrize("input_shape", [(32, 32), (32, 128), (256, 32)])
@parametrize(
"reduction_func",
[
torch.prod,
torch.sum,
torch.argmax,
torch.argmin,
torch.min,
torch.max,
],
)
@parametrize("input_dtype", [torch.float16, torch.bfloat16])
@config.patch("triton.use_block_ptr", True)
def test_low_precision_reduction(
self, input_shape, reduction_func, input_dtype
):
@torch.compile
def func(a, b, c, d):
return reduction_func(a * b * c * d)
inps = (torch.rand(input_shape, device=GPU_TYPE, dtype=input_dtype),) * 4
with config.patch("triton.codegen_upcast_to_fp32", False):
func_opt = torch._dynamo.optimize("inductor")(func)
code = run_and_get_triton_code(func_opt, *inps)
self.assertTrue(".to(tl.float32)" in code)
self.assertEqual(func(*inps), func_opt(*inps))
def test_op_dtype_support(self):
"""
Triton codegen upcasts values to float32 for certain ops.

View File

@ -30,6 +30,7 @@ from sympy.printing.precedence import PRECEDENCE
import torch
import torch._logging
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import identity, preserve_rng_state
from torch._prims_common import is_integer_dtype
@ -2284,6 +2285,27 @@ class TritonKernel(SIMDKernel):
reduction_type: ReductionType,
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
def maybe_upcast(value: CSEVariable) -> CSEVariable:
# Math reductions in FP16/BF16 are less accurate because the Triton compiler does not
# automatically promote to FP32 for accumulation. Additionally, max/min reductions
# do not support FP16/BF16. We manually promote to FP32 here.
return (
ops.to_dtype(value, torch.float32)
if value.dtype
in [
torch.float16,
torch.bfloat16,
]
else value
)
original_dtypes = [val.dtype for val in pytree.tree_leaves(value)]
value = pytree.tree_map(maybe_upcast, value)
if any(x in [torch.float16, torch.bfloat16] for x in original_dtypes):
# Only promote FB16/BF16; do not promote other integer/boolean dtypes
src_dtype = torch.promote_types(src_dtype, torch.float32)
dtype = torch.promote_types(dtype, torch.float32)
assert self.inside_reduction
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
self.filter_masks(masks)
@ -2545,10 +2567,30 @@ class TritonKernel(SIMDKernel):
if isinstance(result_var, tuple):
assert all(isinstance(x, TritonCSEVariable) for x in result_var)
self.outside_loop_vars.update(result_var)
# Match output dtype with input dtype
if reduction_type == "welford_reduce":
assert len(original_dtypes) == 1
original_dtypes = len(result_var) * original_dtypes
assert len(result_var) == len(original_dtypes)
for var, orig_dtype in zip(result_var, original_dtypes):
assert orig_dtype is not None
if var.dtype != orig_dtype:
self.post_loop_combine.writeline(
f"{var} = {var}.to({triton_compute_type(orig_dtype)})"
)
else:
assert isinstance(result_var, TritonCSEVariable)
self.outside_loop_vars.add(result_var)
# Match output dtype with input dtype
if result_var.dtype != original_dtypes[0]:
assert original_dtypes[0] is not None
self.post_loop_combine.writeline(
f"{result_var} = {result_var}.to({triton_compute_type(original_dtypes[0])})"
)
return result_var
def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype):