From 8dbe7f99bd707ee28ae12ecb9cab54e1785bf13e Mon Sep 17 00:00:00 2001 From: David Berard Date: Sat, 16 Aug 2025 10:37:36 -0700 Subject: [PATCH] [BE][inductor] tl.dot(..., allow_tf32=...) -> tl.dot(..., input_precision=...) (#160711) allow_tf32 is deprecated. Also, this will make it easier to support tf32x3 (i.e. #160359). dashboard results on h100 show no change: [inference](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2011%20Aug%202025%2017%3A01%3A22%20GMT&stopTime=Mon%2C%2018%20Aug%202025%2017%3A01%3A22%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/davidberard98/399/orig&lCommit=ce12d0fd751a733f22b5bdda00bd58d323e0a526&rBranch=main&rCommit=e444cd24d48b3a46f067974f2cc157f5ed27709f), [training](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2011%20Aug%202025%2017%3A01%3A22%20GMT&stopTime=Mon%2C%2018%20Aug%202025%2017%3A01%3A22%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/davidberard98/399/orig&lCommit=ce12d0fd751a733f22b5bdda00bd58d323e0a526&rBranch=main&rCommit=e444cd24d48b3a46f067974f2cc157f5ed27709f) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160711 Approved by: https://github.com/PaulZhang12, https://github.com/njriasan --- test/inductor/test_max_autotune.py | 4 ++-- test/inductor/test_triton_kernels.py | 2 +- torch/_inductor/kernel/bmm.py | 2 +- torch/_inductor/kernel/conv.py | 13 +++++++++---- torch/_inductor/kernel/mm.py | 10 +++++----- torch/_inductor/kernel/mm_plus_mm.py | 4 ++-- torch/_inductor/select_algorithm.py | 6 +++--- torch/_inductor/template_heuristics.py | 21 ++++++++++++++------- 8 files changed, 37 insertions(+), 25 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 151f1c3ec592..55fd364f9b91 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1378,7 +1378,7 @@ class TestMaxAutotune(TestCase): 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[10,30], 'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]", 'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity', - 'kwargs':{'EVEN_K':False,'ALLOW_TF32':True,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32', + 'kwargs':{'EVEN_K':False,'FLOAT32_PRECISION':'"tf32"','USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32', 'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}""" expected = expected.replace("cuda", GPU_TYPE) @@ -1417,7 +1417,7 @@ class TestMaxAutotune(TestCase): "[[s27,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"], 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94], 'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0, - 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True, + 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'FLOAT32_PRECISION':'"tf32"', 'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}""" expected = expected.replace("cuda", GPU_TYPE) self.assertExpectedInline( diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index fc9f92477c79..8fb22219302b 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -3081,7 +3081,7 @@ class 
MutationTests(torch._inductor.test_case.TestCase): # Compute output w = tl.load(w1_block_ptr) b = tl.load(b1_block_ptr) - o = tl.dot(x, w, allow_tf32=False) + o = tl.dot(x, w, input_precision="ieee") o += b[None, :] # Store output diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index 92822ecc310b..947175af0470 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -95,7 +95,7 @@ bmm_template = TritonTemplate( else: a = tl.load(A, mask=rk[None, :] < k, other=0.) b = tl.load(B, mask=rk[:, None] < k, other=0.) - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION) A += BLOCK_K * stride_ak B += BLOCK_K * stride_bk diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 6b9e9a1a32e7..3b40bfc21b5e 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -85,7 +85,7 @@ LOOP_BODY_2D = """ ) mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) - acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) + acc += tl.dot(matrix_x, matrix_w, input_precision=FLOAT32_PRECISION) """ """ @@ -214,7 +214,7 @@ LOOP_BODY_3D = """ ) mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) - acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) + acc += tl.dot(matrix_x, matrix_w, input_precision=FLOAT32_PRECISION) """ conv3d_template = TritonTemplate( @@ -390,6 +390,11 @@ def channels_last_order(rank): return order +def _get_float32_precision(): + result = "tf32" if torch.backends.cuda.matmul.allow_tf32 else "ieee" + return f'"{result}"' + + def convert_1x1_conv_to_mm(x, weight, bias): # special case for 1x1 convolution, which is actually just a matmul rank = len(weight.get_size()) @@ -611,7 +616,7 @@ def convolution( # TODO(jansel): try unroll for bigger kernels once fixed: # https://github.com/triton-lang/triton/issues/1254 UNROLL=is_ones(kernel_shape), - ALLOW_TF32=torch.backends.cudnn.allow_tf32, + FLOAT32_PRECISION=_get_float32_precision(), num_stages=cfg.num_stages, num_warps=cfg.num_warps, **cfg.kwargs, @@ -634,7 +639,7 @@ def convolution( # TODO(jansel): try unroll for bigger kernels once fixed: # https://github.com/triton-lang/triton/issues/1254 UNROLL=is_ones(kernel_shape), - ALLOW_TF32=torch.backends.cudnn.allow_tf32, + FLOAT32_PRECISION=_get_float32_precision(), num_stages=cfg.num_stages, num_warps=cfg.num_warps, **cfg.kwargs, diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 7a8a4e1cc32a..e4303879dc87 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -136,9 +136,9 @@ mm_template = TritonTemplate( {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} {% if USE_FAST_ACCUM %} - acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + acc = tl.dot(a, b, acc, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE) {% else %} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE) {% endif %} # rematerialize rm and rn to save registers @@ -211,9 +211,9 @@ mm_template = TritonTemplate( idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} {% if USE_FAST_ACCUM %} - acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + acc = tl.dot(a, b, 
acc, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE) {% else %} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE) {% endif %} # rematerialize rm and rn to save registers @@ -347,7 +347,7 @@ persistent_tma_mm_template = TritonTemplate( acc += tl.dot( a if A_ROW_MAJOR else a.T, b if B_ROW_MAJOR else b.T, - allow_tf32=ALLOW_TF32, + input_precision=FLOAT32_PRECISION, ) if ki == k_tiles - 1: diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index df3e8fcf1e65..d5ab1d2b83e9 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -90,7 +90,7 @@ mm_plus_mm_template = TritonTemplate( else: a = tl.load(A, mask=rk[None, :] < k1, other=0.) b = tl.load(B, mask=rk[:, None] < k1, other=0.) - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION) A += BLOCK_K * stride_ak B += BLOCK_K * stride_bk @@ -103,7 +103,7 @@ mm_plus_mm_template = TritonTemplate( else: c = tl.load(C, mask=rk[None, :] < k2, other=0.) d = tl.load(D, mask=rk[:, None] < k2, other=0.) - acc += tl.dot(c, d, allow_tf32=ALLOW_TF32) + acc += tl.dot(c, d, input_precision=FLOAT32_PRECISION) C += BLOCK_K * stride_ck D += BLOCK_K * stride_dk diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 01337fc0d30b..1f42cf99028c 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1630,7 +1630,7 @@ class TritonTemplate(KernelTemplate): # patch around it here. See https://github.com/triton-lang/triton/issues/3011 # for one example issue with this problem. if torch.cuda.is_available() and not torch.cuda.is_tf32_supported(): - kwargs["ALLOW_TF32"] = "False" + kwargs["FLOAT32_PRECISION"] = '"ieee"' if call_sizes is None: call_sizes = layout.size @@ -1763,7 +1763,7 @@ class TritonTemplate(KernelTemplate): "num_stages": num_stages, "num_warps": num_warps, "GROUP_M": kwargs.get("GROUP_M", -1), - "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), + "float32_precision": str(kwargs.get("FLOAT32_PRECISION", None)), "acc_type": str(kwargs.get("ACC_TYPE", None)), "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0), "waves_per_eu": kwargs.get("waves_per_eu", 0), @@ -2395,12 +2395,12 @@ class AlgorithmSelectorCache(PersistentCache): important_keys = [ "ACC_TYPE", - "ALLOW_TF32", "BLOCK_K", "BLOCK_M", "BLOCK_N", "EVEN_K", "GROUP_M", + "FLOAT32_PRECISION", "USE_FAST_ACCUM", "num_stages", "num_warps", diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index 68b304fdbc61..1b87a61c35d1 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -1316,6 +1316,19 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics): ) yield template_kwargs + @staticmethod + def _get_input_precision( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> str: + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0) + ) + result = "tf32" if allow_tf32 else "ieee" + + # wrap in quotes, because the string will be dropped into the templates + return f'"{result}"' + def _convert_config_to_template_kwargs( self, triton_config: TritonConfig, @@ -1335,16 +1348,10 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics): == triton_config.kwargs["BLOCK_K"] ) - # Calculate allow_tf32 - allow_tf32 = 
torch.backends.cuda.matmul.allow_tf32 and ( - not inductor_config.force_same_precision - or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0) - ) - # Build options dict options_dict = dict( EVEN_K=even_k_symbolic, - ALLOW_TF32=allow_tf32, + FLOAT32_PRECISION=MMTemplateConfigMixin._get_input_precision(m, n, k), USE_FAST_ACCUM=False, # Option for _scaled_mm ACC_TYPE=self._get_acc_type(layout.dtype), num_stages=triton_config.num_stages,
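
Note (illustrative, not part of the patch): the underlying Triton API change tracked here is that `tl.dot`'s `allow_tf32` flag is deprecated in favor of `input_precision`, which accepts `"tf32"`, `"tf32x3"`, or `"ieee"`. A minimal sketch of the new spelling, assuming a recent Triton release that exposes `input_precision` and a CUDA device:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def mm_kernel(A, B, C, N, K,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # One program computes one (BLOCK_M, BLOCK_N) tile of C = A @ B (row-major).
    rm = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        rk = k + tl.arange(0, BLOCK_K)
        a = tl.load(A + rm[:, None] * K + rk[None, :])
        b = tl.load(B + rk[:, None] * N + rn[None, :])
        # Deprecated spelling used by the templates before this patch:
        #   acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        # New spelling; the precision is "tf32", "tf32x3", or "ieee":
        acc += tl.dot(a, b, input_precision="ieee")
    tl.store(C + rm[:, None] * N + rn[None, :], acc)


a = torch.randn(64, 64, device="cuda")
b = torch.randn(64, 64, device="cuda")
c = torch.empty(64, 64, device="cuda")
mm_kernel[(1, 1)](a, b, c, 64, 64, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)
torch.testing.assert_close(c, a @ b, rtol=1e-4, atol=1e-4)
```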
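
Note on the quoting: `FLOAT32_PRECISION` is carried around as an already-quoted string (e.g. `'"tf32"'`) because the kwarg is substituted as-is into the generated Triton source, where it has to read as a string literal. A small sketch of that convention, mirroring the `_get_float32_precision` helper this patch adds to `torch/_inductor/kernel/conv.py`:

```python
import torch


def _get_float32_precision() -> str:
    # tf32 when the global matmul flag allows it, full fp32 ("ieee") otherwise.
    result = "tf32" if torch.backends.cuda.matmul.allow_tf32 else "ieee"
    # Wrap in quotes: the value is dropped verbatim into the kernel template.
    return f'"{result}"'


# Renders e.g.:  acc += tl.dot(a, b, input_precision="ieee")
print(f"acc += tl.dot(a, b, input_precision={_get_float32_precision()})")
```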
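
For the matmul templates, the same decision lives in `MMTemplateConfigMixin._get_input_precision`: tf32 is used only when `torch.backends.cuda.matmul.allow_tf32` is set and, under `inductor_config.force_same_precision`, only for tf32-friendly shapes; `select_algorithm.py` separately pins the kwarg to `"ieee"` on GPUs without tf32 support. A standalone sketch of that combined decision (the free function and its signature are illustrative, not from the patch):

```python
import torch


def choose_float32_precision(m: int, n: int, k: int, force_same_precision: bool) -> str:
    # tf32 only when the global flag is set and, if force_same_precision is
    # enabled, only when M/N are multiples of 16 and K is a multiple of 8.
    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
        not force_same_precision
        or (m % 16 == 0 and n % 16 == 0 and k % 8 == 0)
    )
    # select_algorithm.py's fallback for GPUs that cannot run tf32 at all.
    if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
        allow_tf32 = False
    return f'"{"tf32" if allow_tf32 else "ieee"}"'


print(choose_float32_precision(128, 128, 64, force_same_precision=True))
```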