[inductor triton] Disable incorrect TF32 usage on CUDA capability < 8 (#145684)

Triton 2.2 and later have a bug where allowing TF32 generation for a GPU that does not support TF32 causes code generation errors. Patch around this problem by:

1. Adding a function to `torch.cuda` that determines whether CUDA hardware is capable of using the TF32 format.
2. Using that function to explicitly disable TF32 generation when calling Triton, where needed.

To demonstrate that this fix works, run `test/inductor/test_max_autotune.py` without this patch on a GPU with CUDA compute capability < 8 (e.g. a pre-Ampere NVIDIA consumer GPU) and observe the code generation failure.
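
As a rough illustration of how the new helper is meant to be used, here is a minimal sketch; this is not code from the patch itself, `render_defines`, `BLOCK_M`, and `BLOCK_N` are made-up names, and it assumes a PyTorch build that already includes `torch.cuda.is_tf32_supported`:

```python
import torch

def render_defines(**kwargs):
    # Triton >= 2.2 miscompiles kernels that request TF32 on hardware without
    # TF32 support (CUDA compute capability < 8), so force ALLOW_TF32 off there.
    # The is_available() guard just keeps this sketch safe on CPU-only machines.
    if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
        kwargs["ALLOW_TF32"] = "False"
    # Render the constexpr defines block a Triton template would be built from.
    return "".join(f"{name} : tl.constexpr = {val}\n" for name, val in kwargs.items())

print(render_defines(BLOCK_M=64, BLOCK_N=64, ALLOW_TF32="True"))
```

The actual change applies the same gating inside `TritonTemplate`, as shown in the diff below.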

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145684
Approved by: https://github.com/eqy
Author: Benjamin Glass
Date: 2025-01-27 22:10:04 +00:00
Committed by: PyTorch MergeBot
Parent: 1ffed44b42
Commit: 5aa5a5763e

8 changed files with 35 additions and 32 deletions

View File

@ -29,6 +29,7 @@ torch.cuda
ipc_collect
is_available
is_initialized
+is_tf32_supported
memory_usage
set_device
set_stream

View File

@ -11,12 +11,7 @@ import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import make_tensor
-from torch.testing._internal.common_cuda import (
-    TEST_CUDA,
-    TEST_CUDNN,
-    tf32_is_not_fp32,
-    tf32_on_and_off,
-)
+from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off
from torch.testing._internal.common_device_type import (
disablecuDNN,
disableMkldnn,
@ -65,7 +60,7 @@ from torch.testing._internal.common_utils import (
)
-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()
if TEST_SCIPY:
@ -2077,7 +2072,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
if mode == "same":
actual = actual[:5, :5, :10]
-if tf32_is_not_fp32() and (
+if torch.cuda.is_tf32_supported() and (
dtype == torch.float or dtype == torch.complex64
):
self.assertEqual(actual, expected, atol=0.05, rtol=0.05)
@ -3920,7 +3915,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
inp, w, None, (1, 1), (0, 0), (1, 1), 1
)
self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
-if tf32_is_not_fp32() and dtype == torch.float:
+if torch.cuda.is_tf32_supported() and dtype == torch.float:
self.assertEqual(conv2d_out.relu(), cudnn_out, atol=4e-3, rtol=0.006)
else:
self.assertEqual(conv2d_out.relu(), cudnn_out)
@ -3958,7 +3953,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
)
self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
-if tf32_is_not_fp32() and dtype == torch.float:
+if torch.cuda.is_tf32_supported() and dtype == torch.float:
self.assertEqual(
F.relu(conv2d_out + alpha * z), cudnn_out, atol=2e-3, rtol=0.006
)

View File

@ -51,11 +51,11 @@ import torch.testing._internal.hypothesis_utils as hu
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck, \
GRADCHECK_NONDET_TOL
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
-from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
+from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_off, tf32_on
from torch.types import _TensorOrTensors
from torch.testing._internal.common_mkldnn import bf32_on_and_off
-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -7326,7 +7326,7 @@ def add_test(test, decorator=None):
kwargs['extra_args'] = test.extra_args
if 'dtype' in get_function_arglist(test.test_cuda):
-if tf32_is_not_fp32() and test.with_tf32:
+if torch.cuda.is_tf32_supported() and test.with_tf32:
def with_tf32_off(self, test=test, kwargs=kwargs):
with tf32_off():
@ -7369,7 +7369,7 @@ def add_test(test, decorator=None):
with tf32_off():
test.test_cuda(self, **kwargs)
-if tf32_is_not_fp32() and test.with_tf32:
+if torch.cuda.is_tf32_supported() and test.with_tf32:
add(cuda_test_name + '_fp32', with_tf32_off)
def with_tf32_on(self, test=test, kwargs=kwargs):

View File

@ -56,7 +56,7 @@ from torch.testing._internal.common_device_type import (
import torch.backends.quantized
import torch.testing._internal.data
from torch.testing._internal.common_cuda import (
-tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN, TEST_MULTIGPU,
+tf32_on_and_off, TEST_CUDNN, TEST_MULTIGPU,
_create_scaling_case, _create_scaling_models_optimizers)
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.testing._internal.common_dtype import (
@ -79,7 +79,7 @@ assert torch.get_default_dtype() is torch.float32
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()
@contextlib.contextmanager
def torch_vital_set(value):

View File

@ -11,7 +11,6 @@ import torch.nn as nn
import torch.nn.functional as F
from torch._C._dynamo.guards import assert_size_stride
from torch.testing import make_tensor
-from torch.testing._internal.common_cuda import tf32_is_not_fp32
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
@ -31,7 +30,7 @@ from torch.testing._internal.common_utils import (
)
-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()
if TEST_SCIPY:
import scipy.ndimage
import scipy.signal

View File

@ -1063,6 +1063,14 @@ class TritonTemplate(KernelTemplate):
"""
assert self.template, "requires jinja2"
defines = StringIO()
+# HACK: Triton currently breaks if TF32 floats are requested, but the CUDA
+# capability doesn't support them. This is a bug in Triton, but for now we'll
+# patch around it here. See https://github.com/triton-lang/triton/issues/3011
+# for one example issue with this problem.
+if not torch.cuda.is_tf32_supported():
+    kwargs["ALLOW_TF32"] = "False"
for name, val in kwargs.items():
defines.write(f"{name} : tl.constexpr = {val}\n")
defines = defines.getvalue()

View File

@ -168,6 +168,18 @@ def _check_bf16_tensor_supported(device: _device_t):
return False
+def is_tf32_supported() -> bool:
+    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype tf32."""
+    # Check for ROCm. If true, return false, since PyTorch does not currently support
+    # tf32 on ROCm.
+    if torch.version.hip:
+        return False
+    # Otherwise, tf32 is supported on CUDA platforms that natively (i.e. no emulation)
+    # support bfloat16.
+    return is_bf16_supported(including_emulation=False)
def _sleep(cycles):
torch._C._cuda_sleep(cycles)
@ -1711,6 +1723,7 @@ __all__ = [
"is_bf16_supported",
"is_current_stream_capturing",
"is_initialized",
"is_tf32_supported",
"jiterator",
"list_gpu_processes",
"make_graphed_callables",

View File

@ -120,19 +120,6 @@ def initialize_cuda_context_rng():
__cuda_ctx_rng_initialized = True
-# Test whether hardware TF32 math mode enabled. It is enabled only on:
-# - CUDA >= 11
-# - arch >= Ampere
-def tf32_is_not_fp32():
-    if not torch.cuda.is_available() or torch.version.cuda is None:
-        return False
-    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
-        return False
-    if int(torch.version.cuda.split('.')[0]) < 11:
-        return False
-    return True
@contextlib.contextmanager
def tf32_off():
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
@ -220,7 +207,7 @@ def tf32_on_and_off(tf32_precision=1e-5):
def wrapped(*args, **kwargs):
for k, v in zip(arg_names, args):
kwargs[k] = v
-cond = tf32_is_not_fp32()
+cond = torch.cuda.is_tf32_supported()
if 'device' in kwargs:
cond = cond and (torch.device(kwargs['device']).type == 'cuda')
if 'dtype' in kwargs: