David Berard
2025-08-16 10:37:36 -07:00
committed by PyTorch MergeBot
parent 1d46aa736f
commit 8dbe7f99bd
8 changed files with 37 additions and 25 deletions

View File

@ -1630,7 +1630,7 @@ class TritonTemplate(KernelTemplate):
# patch around it here. See https://github.com/triton-lang/triton/issues/3011
# for one example issue with this problem.
if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
kwargs["ALLOW_TF32"] = "False"
kwargs["FLOAT32_PRECISION"] = '"ieee"'
if call_sizes is None:
call_sizes = layout.size
@ -1763,7 +1763,7 @@ class TritonTemplate(KernelTemplate):
"num_stages": num_stages,
"num_warps": num_warps,
"GROUP_M": kwargs.get("GROUP_M", -1),
"allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
"float32_precision": str(kwargs.get("FLOAT32_PRECISION", None)),
"acc_type": str(kwargs.get("ACC_TYPE", None)),
"matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
"waves_per_eu": kwargs.get("waves_per_eu", 0),
@ -2395,12 +2395,12 @@ class AlgorithmSelectorCache(PersistentCache):
important_keys = [
"ACC_TYPE",
"ALLOW_TF32",
"BLOCK_K",
"BLOCK_M",
"BLOCK_N",
"EVEN_K",
"GROUP_M",
"FLOAT32_PRECISION",
"USE_FAST_ACCUM",
"num_stages",
"num_warps",