Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[Inductor/Triton] Upcast FP16/BF16 math reductions to FP32 (#141052)
Summary: The Triton compiler does not automatically promote FP16/BF16 reductions to FP32 accumulation, which can cause significant accuracy issues. This diff upcasts the input to FP32 for all math reductions: `["welford_reduce", "welford_combine", "prod", "sum", "xor_sum"]`.

Test Plan: CI

```
python test/inductor/test_torchinductor.py TritonCodeGenTests.test_low_precision_reduction
```

Differential Revision: D65965032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141052
Approved by: https://github.com/blaine-rister
Committed by: PyTorch MergeBot
Parent: 816328fa51
Commit: 417d9c3522
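To make the accuracy concern in the Summary concrete, here is a small eager-mode sketch (not part of this PR; the sizes and values are arbitrary illustrations): keeping a running total in FP16 stalls once the total grows large, while upcasting to FP32 before reducing stays close to the exact result.

```python
import torch

# 4096 copies of 0.1: the exact sum is 409.6.
x = torch.full((4096,), 0.1, dtype=torch.float16)

# Simulate an FP16 accumulator: once the running total reaches ~256,
# adding 0.1 rounds to no change at all, so the sum stops growing.
acc16 = torch.zeros((), dtype=torch.float16)
for v in x:
    acc16 = acc16 + v

# Upcast first, then reduce: accumulation happens in FP32.
acc32 = x.to(torch.float32).sum()

print(acc16.item())  # roughly 256, far from the true sum
print(acc32.item())  # roughly 409.6 (up to the FP16 representation of 0.1)
```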
@@ -95,6 +95,35 @@ class TestCase(InductorTestCase):

```python
        fp32_cast_in_code = "to(tl.float32)" in code
        self.assertEqual(fp32_cast_in_code, upcast_to_fp32)

    @requires_gpu()
    @parametrize("input_shape", [(32, 32), (32, 128), (256, 32)])
    @parametrize(
        "reduction_func",
        [
            torch.prod,
            torch.sum,
            torch.argmax,
            torch.argmin,
            torch.min,
            torch.max,
        ],
    )
    @parametrize("input_dtype", [torch.float16, torch.bfloat16])
    @config.patch("triton.use_block_ptr", True)
    def test_low_precision_reduction(
        self, input_shape, reduction_func, input_dtype
    ):
        @torch.compile
        def func(a, b, c, d):
            return reduction_func(a * b * c * d)

        inps = (torch.rand(input_shape, device=GPU_TYPE, dtype=input_dtype),) * 4
        with config.patch("triton.codegen_upcast_to_fp32", False):
            func_opt = torch._dynamo.optimize("inductor")(func)
            code = run_and_get_triton_code(func_opt, *inps)
            self.assertTrue(".to(tl.float32)" in code)
            self.assertEqual(func(*inps), func_opt(*inps))

    def test_op_dtype_support(self):
        """
        Triton codegen upcasts values to float32 for certain ops.
```
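For reference, the same check can be reproduced outside the test harness. The following is a minimal sketch assuming a CUDA device and the internal helpers the test above relies on (`torch._inductor.utils.run_and_get_triton_code` and `torch._inductor.config.patch`); these are private APIs and may change.

```python
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code

def reduce_prod(a, b):
    return torch.prod(a * b)

inps = [torch.rand(32, 128, device="cuda", dtype=torch.float16) for _ in range(2)]

# Disable the blanket pointwise upcast so that any .to(tl.float32) in the
# generated kernel comes from the reduction-specific upcast added here.
with config.patch("triton.codegen_upcast_to_fp32", False):
    compiled = torch.compile(reduce_prod)
    code = run_and_get_triton_code(compiled, *inps)

print(".to(tl.float32)" in code)  # True once the reduction inputs are upcast
```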
@@ -30,6 +30,7 @@ from sympy.printing.precedence import PRECEDENCE

```python
import torch
import torch._logging
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import identity, preserve_rng_state
from torch._prims_common import is_integer_dtype
```
@@ -2284,6 +2285,27 @@ class TritonKernel(SIMDKernel):

```python
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        def maybe_upcast(value: CSEVariable) -> CSEVariable:
            # Math reductions in FP16/BF16 are less accurate because the Triton compiler does not
            # automatically promote to FP32 for accumulation. Additionally, max/min reductions
            # do not support FP16/BF16. We manually promote to FP32 here.
            return (
                ops.to_dtype(value, torch.float32)
                if value.dtype
                in [
                    torch.float16,
                    torch.bfloat16,
                ]
                else value
            )

        original_dtypes = [val.dtype for val in pytree.tree_leaves(value)]
        value = pytree.tree_map(maybe_upcast, value)
        if any(x in [torch.float16, torch.bfloat16] for x in original_dtypes):
            # Only promote FP16/BF16; do not promote other integer/boolean dtypes
            src_dtype = torch.promote_types(src_dtype, torch.float32)
            dtype = torch.promote_types(dtype, torch.float32)

        assert self.inside_reduction
        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
        self.filter_masks(masks)
```
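The comment in `maybe_upcast` describes the resulting codegen pattern: cast FP16/BF16 inputs to FP32, reduce, then cast the result back. A hand-written Triton kernel following the same pattern might look like the sketch below; this is an illustration of the pattern, not the code Inductor actually emits, and it assumes a CUDA device and a single block.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def sum_fp16_upcast_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Upcast before reducing so the accumulation happens in FP32 ...
    acc = tl.sum(x.to(tl.float32), axis=0)
    # ... then cast the result back to the input dtype, which is what the
    # post-loop dtype fixup in the hunk below takes care of in Inductor.
    tl.store(out_ptr, acc.to(tl.float16))

x = torch.rand(1024, device="cuda", dtype=torch.float16)
out = torch.empty(1, device="cuda", dtype=torch.float16)
sum_fp16_upcast_kernel[(1,)](x, out, x.numel(), BLOCK=1024)
print(out.item(), x.to(torch.float32).sum().item())
```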
@@ -2545,10 +2567,30 @@ class TritonKernel(SIMDKernel):

```python
        if isinstance(result_var, tuple):
            assert all(isinstance(x, TritonCSEVariable) for x in result_var)
            self.outside_loop_vars.update(result_var)

            # Match output dtype with input dtype
            if reduction_type == "welford_reduce":
                assert len(original_dtypes) == 1
                original_dtypes = len(result_var) * original_dtypes

            assert len(result_var) == len(original_dtypes)
            for var, orig_dtype in zip(result_var, original_dtypes):
                assert orig_dtype is not None
                if var.dtype != orig_dtype:
                    self.post_loop_combine.writeline(
                        f"{var} = {var}.to({triton_compute_type(orig_dtype)})"
                    )
        else:
            assert isinstance(result_var, TritonCSEVariable)
            self.outside_loop_vars.add(result_var)

            # Match output dtype with input dtype
            if result_var.dtype != original_dtypes[0]:
                assert original_dtypes[0] is not None
                self.post_loop_combine.writeline(
                    f"{result_var} = {result_var}.to({triton_compute_type(original_dtypes[0])})"
                )

        return result_var

    def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype):
```
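In eager terms, the numerics after this change roughly correspond to upcasting, reducing in FP32, and casting the result back, so the output dtype still matches the input. A rough sketch of that equivalence (not code from this PR):

```python
import torch

x = torch.rand(32, 32, dtype=torch.bfloat16)

# Roughly the numerics the upcast gives a compiled torch.sum(x):
ref = x.to(torch.float32).sum().to(x.dtype)

print(ref.dtype)  # torch.bfloat16 -- the output dtype still matches the input
```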