mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[BE][inductor] tl.dot(..., allow_tf32=...) -> tl.dot(..., input_precision=...) (#160711)"
This reverts commit 8dbe7f99bd707ee28ae12ecb9cab54e1785bf13e. Reverted https://github.com/pytorch/pytorch/pull/160711 on behalf of https://github.com/davidberard98 due to internal failure - T235384144 - I'll revert while I investigate. ([comment](https://github.com/pytorch/pytorch/pull/160711#issuecomment-3215343200))
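For context, the reverted PR #160711 switched Inductor's generated Triton templates from the legacy `allow_tf32=` flag of `tl.dot` to the newer `input_precision=` argument; this revert restores the old spelling. Below is a minimal, standalone sketch of the two spellings, assuming a Triton build that accepts both; the kernel, tensor names, and launch parameters are illustrative and not taken from Inductor's templates.

```python
import triton
import triton.language as tl


@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr, ALLOW_TF32: tl.constexpr):
    # Load one BLOCK x BLOCK tile of each operand (row-major, contiguous).
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])

    # Legacy spelling restored by this revert: a boolean template constant
    # (the diff pins it to "False" on GPUs without TF32 support).
    acc = tl.dot(a, b, allow_tf32=ALLOW_TF32)

    # Newer spelling introduced by the reverted PR: a precision string, where
    # allow_tf32=False corresponds roughly to input_precision="ieee".
    # acc = tl.dot(a, b, input_precision="ieee")

    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], acc)


if __name__ == "__main__":
    import torch

    x = torch.randn(32, 32, device="cuda")
    y = torch.randn(32, 32, device="cuda")
    out = torch.empty(32, 32, device="cuda")
    dot_kernel[(1,)](x, y, out, BLOCK=32, ALLOW_TF32=True)
```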
@@ -1699,7 +1699,7 @@ class TritonTemplate(KernelTemplate):
         # patch around it here. See https://github.com/triton-lang/triton/issues/3011
         # for one example issue with this problem.
         if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
-            kwargs["FLOAT32_PRECISION"] = '"ieee"'
+            kwargs["ALLOW_TF32"] = "False"
 
         if call_sizes is None:
             call_sizes = layout.size
@@ -1832,7 +1832,7 @@ class TritonTemplate(KernelTemplate):
             "num_stages": num_stages,
             "num_warps": num_warps,
             "GROUP_M": kwargs.get("GROUP_M", -1),
-            "float32_precision": str(kwargs.get("FLOAT32_PRECISION", None)),
+            "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
             "acc_type": str(kwargs.get("ACC_TYPE", None)),
             "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
             "waves_per_eu": kwargs.get("waves_per_eu", 0),
@@ -2464,12 +2464,12 @@ class AlgorithmSelectorCache(PersistentCache):
 
         important_keys = [
             "ACC_TYPE",
+            "ALLOW_TF32",
             "BLOCK_K",
             "BLOCK_M",
             "BLOCK_N",
             "EVEN_K",
             "GROUP_M",
-            "FLOAT32_PRECISION",
             "USE_FAST_ACCUM",
             "num_stages",
             "num_warps",
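The first hunk above restores the pre-Ampere guard: when CUDA is available but the device cannot do TF32, the template constant is pinned so the generated kernel never requests TF32 from `tl.dot`. Here is a minimal sketch of that guard in isolation, using only the `torch.cuda` calls that appear in the diff; the helper name and the dict-in/dict-out shape are illustrative assumptions, not Inductor's actual API.

```python
import torch


def _pin_tf32_template_arg(kwargs: dict) -> dict:
    # Hypothetical standalone version of the guard restored above. The value is
    # the string "False", not the Python boolean False, because template kwargs
    # are substituted verbatim into the generated Triton source.
    if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
        kwargs["ALLOW_TF32"] = "False"
    return kwargs
```

The second and third hunks keep the surrounding metadata consistent with that kwarg name: the logged field goes back to `allow_tf32`, and `ALLOW_TF32` returns to the `important_keys` list in AlgorithmSelectorCache in place of `FLOAT32_PRECISION`.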