[inductor triton] Disable incorrect TF32 usage on CUDA capability < 8 (#145684)
Triton 2.2 and greater have a bug where allowing TF32 generation for a GPU that does not support TF32 will cause code generation errors. Patch around this problem by:

1. Adding a function to `torch.cuda` that determines whether CUDA hardware is capable of using the TF32 format.
2. Using that function to explicitly disable TF32 generation when calling Triton, where needed.

To demonstrate that this fix works, try running `test/inductor/test_max_autotune.py` on a GPU with CUDA compute capability < 8 (e.g. any NVIDIA consumer GPU) without this fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145684
Approved by: https://github.com/eqy
Committed by: PyTorch MergeBot
Parent: 1ffed44b42
Commit: 5aa5a5763e
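The body of the new capability check is not shown on this page. Below is a minimal sketch of how such a helper could look, assuming it only needs to test for Ampere-or-newer hardware (TF32 requires CUDA compute capability >= 8.0); the actual implementation of `torch.cuda.is_tf32_supported()` in the PR may differ.

```python
import torch

def is_tf32_supported() -> bool:
    # Illustrative only: TF32 tensor cores require an Ampere-class GPU,
    # i.e. CUDA compute capability 8.0 or higher.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8
```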
@@ -56,7 +56,7 @@ from torch.testing._internal.common_device_type import (
 import torch.backends.quantized
 import torch.testing._internal.data
 from torch.testing._internal.common_cuda import (
-    tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN, TEST_MULTIGPU,
+    tf32_on_and_off, TEST_CUDNN, TEST_MULTIGPU,
     _create_scaling_case, _create_scaling_models_optimizers)
 from torch.testing._internal.common_mkldnn import bf32_on_and_off
 from torch.testing._internal.common_dtype import (
@@ -79,7 +79,7 @@ assert torch.get_default_dtype() is torch.float32
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests

-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()

 @contextlib.contextmanager
 def torch_vital_set(value):
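The point of the check is that callers never request TF32 on hardware that cannot execute it. A hedged usage sketch using PyTorch's existing global TF32 switches (the PR's actual change is in the inductor/Triton code paths, which are not shown on this page):

```python
import torch

# Enable TF32 matmul/conv paths only when the GPU actually supports TF32
# (compute capability >= 8). On older GPUs both switches stay off, so no
# TF32 code generation is ever requested.
tf32_ok = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
torch.backends.cuda.matmul.allow_tf32 = tf32_ok
torch.backends.cudnn.allow_tf32 = tf32_ok
```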