Revert "[BE][inductor] tl.dot(..., allow_tf32=...) -> tl.dot(..., input_precision=...) (#160711)"

This reverts commit 8dbe7f99bd707ee28ae12ecb9cab54e1785bf13e.

Reverted https://github.com/pytorch/pytorch/pull/160711 on behalf of https://github.com/davidberard98 due to internal failure - T235384144 - I'll revert while I investigate. ([comment](https://github.com/pytorch/pytorch/pull/160711#issuecomment-3215343200))
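
For context, the reverted PR swapped Triton's boolean `allow_tf32` keyword on `tl.dot` for the newer string-valued `input_precision` keyword. A minimal sketch of the two spellings, assuming a recent Triton where `input_precision` accepts "tf32", "tf32x3", or "ieee"; the kernel, shapes, and constexpr names below are illustrative (they mirror the template kwargs in the diff), not code from the PR:

    # Illustrative sketch of the tl.dot API difference, not code from the PR.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def dot_kernel(a_ptr, b_ptr, c_ptr, FLOAT32_PRECISION: tl.constexpr):
        offs = tl.arange(0, 16)
        a = tl.load(a_ptr + offs[:, None] * 16 + offs[None, :])
        b = tl.load(b_ptr + offs[:, None] * 16 + offs[None, :])
        # New spelling (what #160711 moved to, on newer Triton):
        c = tl.dot(a, b, input_precision=FLOAT32_PRECISION)
        # Old spelling (what this revert restores, on older Triton):
        #   c = tl.dot(a, b, allow_tf32=ALLOW_TF32)
        tl.store(c_ptr + offs[:, None] * 16 + offs[None, :], c)

    a = torch.randn(16, 16, device="cuda")
    b = torch.randn(16, 16, device="cuda")
    c = torch.empty(16, 16, device="cuda")
    dot_kernel[(1,)](a, b, c, FLOAT32_PRECISION="ieee")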
PyTorch MergeBot
2025-08-22 19:10:35 +00:00
parent eba1ad09e4
commit 2c0650a00a
8 changed files with 25 additions and 37 deletions


@@ -1699,7 +1699,7 @@ class TritonTemplate(KernelTemplate):
         # patch around it here. See https://github.com/triton-lang/triton/issues/3011
         # for one example issue with this problem.
         if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
-            kwargs["FLOAT32_PRECISION"] = '"ieee"'
+            kwargs["ALLOW_TF32"] = "False"
         if call_sizes is None:
             call_sizes = layout.size
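
Not part of the diff: a standalone sketch of the guard the hunk above restores. On GPUs without TF32 support (pre-Ampere), the template forces the flag off rather than letting `tl.dot` misbehave (see triton-lang/triton#3011). Here `kwargs` stands in for the template's kernel arguments; `torch.cuda.is_tf32_supported()` is the same check the hunk uses:

    import torch

    kwargs = {}
    if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
        # Stringified because the template splices kwargs into generated
        # Triton source as constexpr values.
        kwargs["ALLOW_TF32"] = "False"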
@@ -1832,7 +1832,7 @@ class TritonTemplate(KernelTemplate):
             "num_stages": num_stages,
             "num_warps": num_warps,
             "GROUP_M": kwargs.get("GROUP_M", -1),
-            "float32_precision": str(kwargs.get("FLOAT32_PRECISION", None)),
+            "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
             "acc_type": str(kwargs.get("ACC_TYPE", None)),
             "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
             "waves_per_eu": kwargs.get("waves_per_eu", 0),
@@ -2464,12 +2464,12 @@ class AlgorithmSelectorCache(PersistentCache):
         important_keys = [
             "ACC_TYPE",
+            "ALLOW_TF32",
             "BLOCK_K",
             "BLOCK_M",
             "BLOCK_N",
             "EVEN_K",
             "GROUP_M",
-            "FLOAT32_PRECISION",
             "USE_FAST_ACCUM",
             "num_stages",
             "num_warps",