David Berard
2025-08-16 10:37:36 -07:00
committed by PyTorch MergeBot
parent 1d46aa736f
commit 8dbe7f99bd
8 changed files with 37 additions and 25 deletions

View File

@@ -1378,7 +1378,7 @@ class TestMaxAutotune(TestCase):
'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[10,30],
'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]",
'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity',
-'kwargs':{'EVEN_K':False,'ALLOW_TF32':True,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32',
+'kwargs':{'EVEN_K':False,'FLOAT32_PRECISION':'"tf32"','USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32',
'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}"""
expected = expected.replace("cuda", GPU_TYPE)
@@ -1417,7 +1417,7 @@ class TestMaxAutotune(TestCase):
"[[s27,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"],
'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94],
'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0,
-'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True,
+'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'FLOAT32_PRECISION':'"tf32"',
'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}"""
expected = expected.replace("cuda", GPU_TYPE)
self.assertExpectedInline(

View File

@@ -3081,7 +3081,7 @@ class MutationTests(torch._inductor.test_case.TestCase):
# Compute output
w = tl.load(w1_block_ptr)
b = tl.load(b1_block_ptr)
-o = tl.dot(x, w, allow_tf32=False)
+o = tl.dot(x, w, input_precision="ieee")
o += b[None, :]
# Store output

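Note: the substantive change in every kernel and template below is the same one shown here. Triton's tl.dot has moved from the boolean allow_tf32 argument to input_precision, which takes a string such as "ieee", "tf32", or (in recent Triton releases) "tf32x3". A minimal standalone sketch of the new spelling, not part of this diff, with illustrative kernel and tensor names:

    # Minimal sketch of the new tl.dot spelling; assumes a recent Triton build where
    # input_precision accepts "ieee", "tf32", or "tf32x3". Names are illustrative.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def tiny_matmul(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
        b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
        # old: acc = tl.dot(a, b, allow_tf32=False)
        # new: request full fp32 accuracy explicitly
        acc = tl.dot(a, b, input_precision="ieee")
        tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], acc)

    a = torch.randn(16, 16, device="cuda")
    b = torch.randn(16, 16, device="cuda")
    c = torch.empty(16, 16, device="cuda")
    tiny_matmul[(1,)](a, b, c, BLOCK=16)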
View File

@@ -95,7 +95,7 @@ bmm_template = TritonTemplate(
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
-acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk

View File

@@ -85,7 +85,7 @@ LOOP_BODY_2D = """
)
mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
-acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
+acc += tl.dot(matrix_x, matrix_w, input_precision=FLOAT32_PRECISION)
"""
"""
@@ -214,7 +214,7 @@ LOOP_BODY_3D = """
)
mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
-acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
+acc += tl.dot(matrix_x, matrix_w, input_precision=FLOAT32_PRECISION)
"""
conv3d_template = TritonTemplate(
@@ -390,6 +390,11 @@ def channels_last_order(rank):
return order
+def _get_float32_precision():
+    result = "tf32" if torch.backends.cuda.matmul.allow_tf32 else "ieee"
+    return f'"{result}"'
def convert_1x1_conv_to_mm(x, weight, bias):
# special case for 1x1 convolution, which is actually just a matmul
rank = len(weight.get_size())
@@ -611,7 +616,7 @@ def convolution(
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
-ALLOW_TF32=torch.backends.cudnn.allow_tf32,
+FLOAT32_PRECISION=_get_float32_precision(),
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
@@ -634,7 +639,7 @@ def convolution(
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
-ALLOW_TF32=torch.backends.cudnn.allow_tf32,
+FLOAT32_PRECISION=_get_float32_precision(),
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,

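The new _get_float32_precision helper above returns the value wrapped in an extra pair of quotes because template kwargs are dropped into the generated Triton source as text, so the precision has to land in the kernel as a string literal. A simplified, hypothetical illustration of that substitution (Inductor's real template rendering is more involved):

    # Hypothetical illustration of the quoting: the kwarg value is substituted into
    # kernel source text, so the nested quotes make it a string literal after rendering.
    template_line = "acc += tl.dot(matrix_x, matrix_w, input_precision=FLOAT32_PRECISION)"
    kwargs = {"FLOAT32_PRECISION": '"tf32"'}  # note the nested quotes
    rendered = template_line.replace("FLOAT32_PRECISION", kwargs["FLOAT32_PRECISION"])
    print(rendered)  # acc += tl.dot(matrix_x, matrix_w, input_precision="tf32")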
View File

@@ -136,9 +136,9 @@ mm_template = TritonTemplate(
{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}
{% if USE_FAST_ACCUM %}
-acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
+acc = tl.dot(a, b, acc, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE)
{% else %}
-acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
+acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE)
{% endif %}
# rematerialize rm and rn to save registers
@@ -211,9 +211,9 @@ mm_template = TritonTemplate(
idx_n = offs_b_n[None, :]
{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}
{% if USE_FAST_ACCUM %}
-acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
+acc = tl.dot(a, b, acc, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE)
{% else %}
-acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
+acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE)
{% endif %}
# rematerialize rm and rn to save registers
@@ -347,7 +347,7 @@ persistent_tma_mm_template = TritonTemplate(
acc += tl.dot(
a if A_ROW_MAJOR else a.T,
b if B_ROW_MAJOR else b.T,
-allow_tf32=ALLOW_TF32,
+input_precision=FLOAT32_PRECISION,
)
if ki == k_tiles - 1:

View File

@@ -90,7 +90,7 @@ mm_plus_mm_template = TritonTemplate(
else:
a = tl.load(A, mask=rk[None, :] < k1, other=0.)
b = tl.load(B, mask=rk[:, None] < k1, other=0.)
-acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
@@ -103,7 +103,7 @@ mm_plus_mm_template = TritonTemplate(
else:
c = tl.load(C, mask=rk[None, :] < k2, other=0.)
d = tl.load(D, mask=rk[:, None] < k2, other=0.)
-acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
+acc += tl.dot(c, d, input_precision=FLOAT32_PRECISION)
C += BLOCK_K * stride_ck
D += BLOCK_K * stride_dk

View File

@@ -1630,7 +1630,7 @@ class TritonTemplate(KernelTemplate):
# patch around it here. See https://github.com/triton-lang/triton/issues/3011
# for one example issue with this problem.
if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
kwargs["ALLOW_TF32"] = "False"
kwargs["FLOAT32_PRECISION"] = '"ieee"'
if call_sizes is None:
call_sizes = layout.size
@@ -1763,7 +1763,7 @@ class TritonTemplate(KernelTemplate):
"num_stages": num_stages,
"num_warps": num_warps,
"GROUP_M": kwargs.get("GROUP_M", -1),
"allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
"float32_precision": str(kwargs.get("FLOAT32_PRECISION", None)),
"acc_type": str(kwargs.get("ACC_TYPE", None)),
"matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
"waves_per_eu": kwargs.get("waves_per_eu", 0),
@@ -2395,12 +2395,12 @@ class AlgorithmSelectorCache(PersistentCache):
important_keys = [
"ACC_TYPE",
"ALLOW_TF32",
"BLOCK_K",
"BLOCK_M",
"BLOCK_N",
"EVEN_K",
"GROUP_M",
"FLOAT32_PRECISION",
"USE_FAST_ACCUM",
"num_stages",
"num_warps",

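Besides renaming the autotune-cache key and the logged field, the guard above pins the templates to full fp32 on GPUs that cannot run TF32, regardless of the torch flag. A hedged sketch of the combined decision; pick_template_precision is a made-up name, but both torch calls appear in the diff above, and in the real code the hardware override and the default live in separate places:

    # Sketch of the precision choice collapsed into one function (assumed structure).
    import torch

    def pick_template_precision() -> str:
        if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
            return '"ieee"'  # hardware cannot do TF32, so force full fp32
        return '"tf32"' if torch.backends.cuda.matmul.allow_tf32 else '"ieee"'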
View File

@@ -1316,6 +1316,19 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
)
yield template_kwargs
+@staticmethod
+def _get_input_precision(
+    m: sympy.Integer, n: sympy.Integer, k: sympy.Integer
+) -> str:
+    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
+        not inductor_config.force_same_precision
+        or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0)
+    )
+    result = "tf32" if allow_tf32 else "ieee"
+    # wrap in quotes, because the string will be dropped into the templates
+    return f'"{result}"'
def _convert_config_to_template_kwargs(
self,
triton_config: TritonConfig,
@@ -1335,16 +1348,10 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
== triton_config.kwargs["BLOCK_K"]
)
-# Calculate allow_tf32
-allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
-    not inductor_config.force_same_precision
-    or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0)
-)
# Build options dict
options_dict = dict(
EVEN_K=even_k_symbolic,
-ALLOW_TF32=allow_tf32,
+FLOAT32_PRECISION=MMTemplateConfigMixin._get_input_precision(m, n, k),
USE_FAST_ACCUM=False, # Option for _scaled_mm
ACC_TYPE=self._get_acc_type(layout.dtype),
num_stages=triton_config.num_stages,
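The alignment check in _get_input_precision mirrors the removed inline logic: with force_same_precision set, TF32 is only requested when the problem sizes are tensor-core friendly. A small worked example of that check, with shapes chosen purely for illustration:

    # Worked example of the m/n/k alignment heuristic above.
    for m, n, k in [(128, 128, 128), (128, 128, 130)]:
        aligned = (m % 16 == 0) and (n % 16 == 0) and (k % 8 == 0)
        print(m, n, k, '"tf32"' if aligned else '"ieee"')
    # 128 128 128 -> "tf32"
    # 128 128 130 -> "ieee"  (k=130 is not a multiple of 8)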