mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor/Triton] Upcast FP16/BF16 math reductions to FP32 (#141052)
Summary: Triton compiler does not automatically promote fp16/bf16 reductions to fp32 accumulation. This will result in significant accuracy issue. This diff will upcast the input to FP32 for all math reductions `["welford_reduce", "welford_combine", "prod", "sum", "xor_sum"]` Test Plan: CI ``` python test/inductor/test_torchinductor.py TritonCodeGenTests.test_low_precision_reduction ``` Differential Revision: D65965032 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141052 Approved by: https://github.com/blaine-rister
This commit is contained in:
committed by
PyTorch MergeBot
parent
816328fa51
commit
417d9c3522
@ -30,6 +30,7 @@ from sympy.printing.precedence import PRECEDENCE
|
||||
|
||||
import torch
|
||||
import torch._logging
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._dynamo.device_interface import get_interface_for_device
|
||||
from torch._dynamo.utils import identity, preserve_rng_state
|
||||
from torch._prims_common import is_integer_dtype
|
||||
@ -2284,6 +2285,27 @@ class TritonKernel(SIMDKernel):
|
||||
reduction_type: ReductionType,
|
||||
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
|
||||
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
|
||||
def maybe_upcast(value: CSEVariable) -> CSEVariable:
|
||||
# Math reductions in FP16/BF16 are less accurate because the Triton compiler does not
|
||||
# automatically promote to FP32 for accumulation. Additionally, max/min reductions
|
||||
# do not support FP16/BF16. We manually promote to FP32 here.
|
||||
return (
|
||||
ops.to_dtype(value, torch.float32)
|
||||
if value.dtype
|
||||
in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
else value
|
||||
)
|
||||
|
||||
original_dtypes = [val.dtype for val in pytree.tree_leaves(value)]
|
||||
value = pytree.tree_map(maybe_upcast, value)
|
||||
if any(x in [torch.float16, torch.bfloat16] for x in original_dtypes):
|
||||
# Only promote FB16/BF16; do not promote other integer/boolean dtypes
|
||||
src_dtype = torch.promote_types(src_dtype, torch.float32)
|
||||
dtype = torch.promote_types(dtype, torch.float32)
|
||||
|
||||
assert self.inside_reduction
|
||||
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
|
||||
self.filter_masks(masks)
|
||||
@ -2545,10 +2567,30 @@ class TritonKernel(SIMDKernel):
|
||||
if isinstance(result_var, tuple):
|
||||
assert all(isinstance(x, TritonCSEVariable) for x in result_var)
|
||||
self.outside_loop_vars.update(result_var)
|
||||
|
||||
# Match output dtype with input dtype
|
||||
if reduction_type == "welford_reduce":
|
||||
assert len(original_dtypes) == 1
|
||||
original_dtypes = len(result_var) * original_dtypes
|
||||
|
||||
assert len(result_var) == len(original_dtypes)
|
||||
for var, orig_dtype in zip(result_var, original_dtypes):
|
||||
assert orig_dtype is not None
|
||||
if var.dtype != orig_dtype:
|
||||
self.post_loop_combine.writeline(
|
||||
f"{var} = {var}.to({triton_compute_type(orig_dtype)})"
|
||||
)
|
||||
else:
|
||||
assert isinstance(result_var, TritonCSEVariable)
|
||||
self.outside_loop_vars.add(result_var)
|
||||
|
||||
# Match output dtype with input dtype
|
||||
if result_var.dtype != original_dtypes[0]:
|
||||
assert original_dtypes[0] is not None
|
||||
self.post_loop_combine.writeline(
|
||||
f"{result_var} = {result_var}.to({triton_compute_type(original_dtypes[0])})"
|
||||
)
|
||||
|
||||
return result_var
|
||||
|
||||
def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype):
|
||||
|
Reference in New Issue
Block a user