Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[BE][inductor] tl.dot(..., allow_tf32=...) -> tl.dot(..., input_precision=...) (#160711)
allow_tf32 is deprecated as an argument to tl.dot. Switching to input_precision will also make it easier to support tf32x3 (i.e. #160359). Dashboard results on H100 show no change: [inference](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2011%20Aug%202025%2017%3A01%3A22%20GMT&stopTime=Mon%2C%2018%20Aug%202025%2017%3A01%3A22%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/davidberard98/399/orig&lCommit=ce12d0fd751a733f22b5bdda00bd58d323e0a526&rBranch=main&rCommit=e444cd24d48b3a46f067974f2cc157f5ed27709f), [training](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2011%20Aug%202025%2017%3A01%3A22%20GMT&stopTime=Mon%2C%2018%20Aug%202025%2017%3A01%3A22%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/davidberard98/399/orig&lCommit=ce12d0fd751a733f22b5bdda00bd58d323e0a526&rBranch=main&rCommit=e444cd24d48b3a46f067974f2cc157f5ed27709f)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160711
Approved by: https://github.com/PaulZhang12, https://github.com/njriasan
Committed by: PyTorch MergeBot
Parent: 1d46aa736f
Commit: 8dbe7f99bd
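
As background for the diff below, here is a minimal, illustrative sketch of the kwarg migration on a toy Triton kernel (not PyTorch code); it assumes a Triton version where tl.dot accepts input_precision with values "tf32", "tf32x3", or "ieee" and deprecates the allow_tf32 boolean:

```python
# Toy Triton kernel (illustrative only, not from PyTorch) showing the migration
# from the deprecated allow_tf32 boolean to the input_precision string.
import triton
import triton.language as tl


@triton.jit
def matmul_tile_kernel(a_ptr, b_ptr, c_ptr,
                       M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # M, N, K are assumed to be powers of two and at least 16, as tl.dot requires.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])

    # Before: acc = tl.dot(a, b, allow_tf32=True)   # deprecated boolean switch
    # After:  pick "tf32", "tf32x3", or "ieee" explicitly
    acc = tl.dot(a, b, input_precision="tf32")

    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```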
@@ -1378,7 +1378,7 @@ class TestMaxAutotune(TestCase):
 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[10,30],
 'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]",
 'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity',
-'kwargs':{'EVEN_K':False,'ALLOW_TF32':True,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32',
+'kwargs':{'EVEN_K':False,'FLOAT32_PRECISION':'"tf32"','USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32',
 'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}"""

         expected = expected.replace("cuda", GPU_TYPE)
@@ -1417,7 +1417,7 @@ class TestMaxAutotune(TestCase):
 "[[s27,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"],
 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94],
 'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0,
-'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True,
+'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'FLOAT32_PRECISION':'"tf32"',
 'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}"""

         expected = expected.replace("cuda", GPU_TYPE)
         self.assertExpectedInline(
@@ -3081,7 +3081,7 @@ class MutationTests(torch._inductor.test_case.TestCase):
             # Compute output
             w = tl.load(w1_block_ptr)
             b = tl.load(b1_block_ptr)
-            o = tl.dot(x, w, allow_tf32=False)
+            o = tl.dot(x, w, input_precision="ieee")
             o += b[None, :]

             # Store output
@@ -95,7 +95,7 @@ bmm_template = TritonTemplate(
         else:
             a = tl.load(A, mask=rk[None, :] < k, other=0.)
             b = tl.load(B, mask=rk[:, None] < k, other=0.)
-        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+        acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION)
         A += BLOCK_K * stride_ak
         B += BLOCK_K * stride_bk

@@ -85,7 +85,7 @@ LOOP_BODY_2D = """
         )
         mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
         matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
-        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
+        acc += tl.dot(matrix_x, matrix_w, input_precision=FLOAT32_PRECISION)
 """

 """
@@ -214,7 +214,7 @@ LOOP_BODY_3D = """
         )
         mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
         matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
-        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
+        acc += tl.dot(matrix_x, matrix_w, input_precision=FLOAT32_PRECISION)
 """

 conv3d_template = TritonTemplate(
@@ -390,6 +390,11 @@ def channels_last_order(rank):
     return order


+def _get_float32_precision():
+    result = "tf32" if torch.backends.cuda.matmul.allow_tf32 else "ieee"
+    return f'"{result}"'
+
+
 def convert_1x1_conv_to_mm(x, weight, bias):
     # special case for 1x1 convolution, which is actually just a matmul
     rank = len(weight.get_size())
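
Note the quoting in _get_float32_precision(): the returned value is dropped textually into the Triton template source, so it must carry its own quotes to appear as a string literal in the generated kernel. A toy illustration of that idea (plain string substitution; the real code path goes through Inductor's TritonTemplate rendering, not str.replace):

```python
# Toy sketch of why the helper returns '"tf32"' rather than 'tf32'.
import torch


def _float32_precision_sketch() -> str:
    # Mirrors _get_float32_precision() from the hunk above.
    result = "tf32" if torch.backends.cuda.matmul.allow_tf32 else "ieee"
    return f'"{result}"'


template_line = "acc += tl.dot(matrix_x, matrix_w, input_precision=FLOAT32_PRECISION)"
rendered = template_line.replace("FLOAT32_PRECISION", _float32_precision_sketch())
print(rendered)
# e.g. acc += tl.dot(matrix_x, matrix_w, input_precision="tf32")
```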
@@ -611,7 +616,7 @@ def convolution(
             # TODO(jansel): try unroll for bigger kernels once fixed:
             # https://github.com/triton-lang/triton/issues/1254
             UNROLL=is_ones(kernel_shape),
-            ALLOW_TF32=torch.backends.cudnn.allow_tf32,
+            FLOAT32_PRECISION=_get_float32_precision(),
             num_stages=cfg.num_stages,
             num_warps=cfg.num_warps,
             **cfg.kwargs,
@@ -634,7 +639,7 @@ def convolution(
             # TODO(jansel): try unroll for bigger kernels once fixed:
             # https://github.com/triton-lang/triton/issues/1254
             UNROLL=is_ones(kernel_shape),
-            ALLOW_TF32=torch.backends.cudnn.allow_tf32,
+            FLOAT32_PRECISION=_get_float32_precision(),
             num_stages=cfg.num_stages,
             num_warps=cfg.num_warps,
             **cfg.kwargs,
@@ -136,9 +136,9 @@ mm_template = TritonTemplate(
         {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}

         {% if USE_FAST_ACCUM %}
-        acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
+        acc = tl.dot(a, b, acc, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE)
         {% else %}
-        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
+        acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE)
         {% endif %}

     # rematerialize rm and rn to save registers
@@ -211,9 +211,9 @@ mm_template = TritonTemplate(
         idx_n = offs_b_n[None, :]
         {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}
         {% if USE_FAST_ACCUM %}
-        acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
+        acc = tl.dot(a, b, acc, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE)
         {% else %}
-        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
+        acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION, out_dtype=ACC_TYPE)
         {% endif %}

     # rematerialize rm and rn to save registers
@@ -347,7 +347,7 @@ persistent_tma_mm_template = TritonTemplate(
             acc += tl.dot(
                 a if A_ROW_MAJOR else a.T,
                 b if B_ROW_MAJOR else b.T,
-                allow_tf32=ALLOW_TF32,
+                input_precision=FLOAT32_PRECISION,
             )

             if ki == k_tiles - 1:
@@ -90,7 +90,7 @@ mm_plus_mm_template = TritonTemplate(
         else:
             a = tl.load(A, mask=rk[None, :] < k1, other=0.)
             b = tl.load(B, mask=rk[:, None] < k1, other=0.)
-        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+        acc += tl.dot(a, b, input_precision=FLOAT32_PRECISION)
         A += BLOCK_K * stride_ak
         B += BLOCK_K * stride_bk

@@ -103,7 +103,7 @@ mm_plus_mm_template = TritonTemplate(
         else:
             c = tl.load(C, mask=rk[None, :] < k2, other=0.)
             d = tl.load(D, mask=rk[:, None] < k2, other=0.)
-        acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
+        acc += tl.dot(c, d, input_precision=FLOAT32_PRECISION)
         C += BLOCK_K * stride_ck
         D += BLOCK_K * stride_dk
@@ -1630,7 +1630,7 @@ class TritonTemplate(KernelTemplate):
         # patch around it here. See https://github.com/triton-lang/triton/issues/3011
         # for one example issue with this problem.
         if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
-            kwargs["ALLOW_TF32"] = "False"
+            kwargs["FLOAT32_PRECISION"] = '"ieee"'

         if call_sizes is None:
             call_sizes = layout.size
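
For context, the patch in this hunk and the _get_float32_precision() helper earlier can be read as a single policy: GPUs without tf32 support are always forced to "ieee", otherwise the torch.backends.cuda.matmul.allow_tf32 flag decides. A hypothetical standalone helper sketching that combined logic (illustrative only, not the actual Inductor code path):

```python
import torch


def effective_float32_precision() -> str:
    # Sketch combining the two pieces above: never emit input_precision="tf32"
    # on hardware that cannot execute tf32, otherwise follow the matmul flag.
    if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
        return '"ieee"'
    return '"tf32"' if torch.backends.cuda.matmul.allow_tf32 else '"ieee"'
```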
@@ -1763,7 +1763,7 @@ class TritonTemplate(KernelTemplate):
             "num_stages": num_stages,
             "num_warps": num_warps,
             "GROUP_M": kwargs.get("GROUP_M", -1),
-            "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
+            "float32_precision": str(kwargs.get("FLOAT32_PRECISION", None)),
             "acc_type": str(kwargs.get("ACC_TYPE", None)),
             "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
             "waves_per_eu": kwargs.get("waves_per_eu", 0),
@@ -2395,12 +2395,12 @@ class AlgorithmSelectorCache(PersistentCache):

         important_keys = [
             "ACC_TYPE",
-            "ALLOW_TF32",
             "BLOCK_K",
             "BLOCK_M",
             "BLOCK_N",
             "EVEN_K",
             "GROUP_M",
+            "FLOAT32_PRECISION",
             "USE_FAST_ACCUM",
             "num_stages",
             "num_warps",
@@ -1316,6 +1316,19 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
             )
             yield template_kwargs

+    @staticmethod
+    def _get_input_precision(
+        m: sympy.Integer, n: sympy.Integer, k: sympy.Integer
+    ) -> str:
+        allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
+            not inductor_config.force_same_precision
+            or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0)
+        )
+        result = "tf32" if allow_tf32 else "ieee"
+
+        # wrap in quotes, because the string will be dropped into the templates
+        return f'"{result}"'
+
     def _convert_config_to_template_kwargs(
         self,
         triton_config: TritonConfig,
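
The alignment rule in _get_input_precision above only permits tf32 under force_same_precision when M and N are multiples of 16 and K is a multiple of 8. A self-contained, illustrative recreation of that rule (the real method reads torch.backends.cuda.matmul.allow_tf32 and inductor_config.force_same_precision; here they are plain parameters so the snippet runs anywhere):

```python
# Illustrative-only recreation of the alignment rule in _get_input_precision.
def precision_for(m: int, n: int, k: int,
                  allow_tf32_flag: bool = True,
                  force_same_precision: bool = True) -> str:
    allow_tf32 = allow_tf32_flag and (
        not force_same_precision
        or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0)
    )
    return "tf32" if allow_tf32 else "ieee"


print(precision_for(128, 64, 32))  # tf32: 128 and 64 are multiples of 16, 32 of 8
print(precision_for(100, 64, 32))  # ieee: 100 % 16 == 4, so the alignment check fails
```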
@@ -1335,16 +1348,10 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
             == triton_config.kwargs["BLOCK_K"]
         )

-        # Calculate allow_tf32
-        allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
-            not inductor_config.force_same_precision
-            or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0)
-        )
-
         # Build options dict
         options_dict = dict(
             EVEN_K=even_k_symbolic,
-            ALLOW_TF32=allow_tf32,
+            FLOAT32_PRECISION=MMTemplateConfigMixin._get_input_precision(m, n, k),
             USE_FAST_ACCUM=False,  # Option for _scaled_mm
             ACC_TYPE=self._get_acc_type(layout.dtype),
             num_stages=triton_config.num_stages,