[BE][inductor] tl.dot(..., allow_tf32=...) -> tl.dot(..., input_precision=...) (#160711)
`allow_tf32` is deprecated. This also makes it easier to support tf32x3 (see #160359).

Dashboard results on h100 show no change: [inference](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2011%20Aug%202025%2017%3A01%3A22%20GMT&stopTime=Mon%2C%2018%20Aug%202025%2017%3A01%3A22%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/davidberard98/399/orig&lCommit=ce12d0fd751a733f22b5bdda00bd58d323e0a526&rBranch=main&rCommit=e444cd24d48b3a46f067974f2cc157f5ed27709f), [training](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2011%20Aug%202025%2017%3A01%3A22%20GMT&stopTime=Mon%2C%2018%20Aug%202025%2017%3A01%3A22%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/davidberard98/399/orig&lCommit=ce12d0fd751a733f22b5bdda00bd58d323e0a526&rBranch=main&rCommit=e444cd24d48b3a46f067974f2cc157f5ed27709f)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160711
Approved by: https://github.com/PaulZhang12, https://github.com/njriasan
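For context, a minimal standalone kernel (hypothetical, not taken from this diff) illustrating the Triton API change the templates are migrating to: the boolean `allow_tf32` flag becomes the string-valued `input_precision`, which can also express `"tf32x3"`.

```python
import triton
import triton.language as tl

@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # Old spelling (deprecated upstream):
    #   acc = tl.dot(a, b, allow_tf32=True)
    # New spelling: "tf32" matches allow_tf32=True, "ieee" matches
    # allow_tf32=False, and "tf32x3" becomes expressible.
    acc = tl.dot(a, b, input_precision="tf32")
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], acc)
```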
commit 8dbe7f99bd (parent 1d46aa736f)
committed by PyTorch MergeBot
```diff
@@ -1630,7 +1630,7 @@ class TritonTemplate(KernelTemplate):
         # patch around it here. See https://github.com/triton-lang/triton/issues/3011
         # for one example issue with this problem.
         if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
-            kwargs["ALLOW_TF32"] = "False"
+            kwargs["FLOAT32_PRECISION"] = '"ieee"'
 
         if call_sizes is None:
             call_sizes = layout.size
```
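The doubly quoted `'"ieee"'` is deliberate: template kwargs are substituted textually into the generated Triton source, so the Python value must itself contain a quoted string literal, just as the old `"False"` rendered as the bare constant `False`. A minimal sketch of the idea, using a hypothetical one-line template rather than inductor's real template engine:

```python
# Hypothetical one-line template; inductor's real templates are larger,
# but the substitution principle is the same: values are pasted verbatim.
template = "acc = tl.dot(a, b, input_precision={FLOAT32_PRECISION})"

print(template.format(FLOAT32_PRECISION='"ieee"'))
# -> acc = tl.dot(a, b, input_precision="ieee")
```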
```diff
@@ -1763,7 +1763,7 @@ class TritonTemplate(KernelTemplate):
             "num_stages": num_stages,
             "num_warps": num_warps,
             "GROUP_M": kwargs.get("GROUP_M", -1),
-            "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
+            "float32_precision": str(kwargs.get("FLOAT32_PRECISION", None)),
             "acc_type": str(kwargs.get("ACC_TYPE", None)),
             "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
             "waves_per_eu": kwargs.get("waves_per_eu", 0),
```
```diff
@@ -2395,12 +2395,12 @@ class AlgorithmSelectorCache(PersistentCache):
 
         important_keys = [
             "ACC_TYPE",
-            "ALLOW_TF32",
             "BLOCK_K",
             "BLOCK_M",
             "BLOCK_N",
             "EVEN_K",
             "GROUP_M",
+            "FLOAT32_PRECISION",
             "USE_FAST_ACCUM",
             "num_stages",
             "num_warps",
```
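A key list like this typically filters a candidate's full kwargs down to the tuning-relevant fields worth reporting or comparing, which is why the renamed key must appear here for `FLOAT32_PRECISION` to survive the filter. A generic sketch of the pattern (hypothetical helper, not `AlgorithmSelectorCache`'s actual code):

```python
# Hypothetical helper illustrating the filtering pattern; not the
# actual AlgorithmSelectorCache implementation.
def summarize(kwargs: dict, important_keys: list[str]) -> dict:
    # Keep only the tuning-relevant fields, in a stable order.
    return {k: kwargs[k] for k in important_keys if k in kwargs}

info = summarize(
    {"BLOCK_M": 64, "BLOCK_N": 64, "FLOAT32_PRECISION": '"ieee"', "debug": True},
    ["BLOCK_M", "BLOCK_N", "FLOAT32_PRECISION"],
)
# {'BLOCK_M': 64, 'BLOCK_N': 64, 'FLOAT32_PRECISION': '"ieee"'}
```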