[inductor triton] Disable incorrect TF32 usage on CUDA capability < 8 (#145684)
Triton 2.2 and greater have a bug where allowing TF32 generation for a GPU that does not support TF32 causes code generation errors. Patch around this problem by:

1. Adding a function to `torch.cuda` that determines whether the CUDA hardware is capable of using the TF32 format.
2. Using that function to explicitly disable TF32 generation when calling Triton, where needed.

To demonstrate that this fix works, run `test/inductor/test_max_autotune.py` without this fix on a GPU with CUDA compute capability < 8 (e.g. a pre-Ampere NVIDIA consumer GPU).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145684
Approved by: https://github.com/eqy
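As a quick illustration (not part of the diff), a caller can use the new helper to gate TF32 the same way the inductor change below does. This is a minimal sketch: the `kwargs` dict and the `"ALLOW_TF32"` key mirror the Triton template code in this PR, while seeding the value from `torch.backends.cuda.matmul.allow_tf32` is only an illustrative assumption.

```python
import torch

# Build the constexpr defines passed to a Triton template (sketch).
kwargs = {"ALLOW_TF32": str(torch.backends.cuda.matmul.allow_tf32)}

if torch.cuda.is_available() and not torch.cuda.is_tf32_supported():
    # Pre-Ampere GPUs (compute capability < 8) cannot execute TF32, and
    # Triton >= 2.2 miscompiles if ALLOW_TF32 is requested anyway, so force it off.
    kwargs["ALLOW_TF32"] = "False"
```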
Committed by: PyTorch MergeBot
Parent commit: 1ffed44b42
Commit: 5aa5a5763e
@@ -29,6 +29,7 @@ torch.cuda
     ipc_collect
     is_available
     is_initialized
+    is_tf32_supported
     memory_usage
     set_device
     set_stream
@@ -11,12 +11,7 @@ import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.testing import make_tensor
-from torch.testing._internal.common_cuda import (
-    TEST_CUDA,
-    TEST_CUDNN,
-    tf32_is_not_fp32,
-    tf32_on_and_off,
-)
+from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off
 from torch.testing._internal.common_device_type import (
     disablecuDNN,
     disableMkldnn,
@@ -65,7 +60,7 @@ from torch.testing._internal.common_utils import (
 )


-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()


 if TEST_SCIPY:
@@ -2077,7 +2072,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
             if mode == "same":
                 actual = actual[:5, :5, :10]

-            if tf32_is_not_fp32() and (
+            if torch.cuda.is_tf32_supported() and (
                 dtype == torch.float or dtype == torch.complex64
             ):
                 self.assertEqual(actual, expected, atol=0.05, rtol=0.05)
@@ -3920,7 +3915,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
             inp, w, None, (1, 1), (0, 0), (1, 1), 1
         )
         self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
-        if tf32_is_not_fp32() and dtype == torch.float:
+        if torch.cuda.is_tf32_supported() and dtype == torch.float:
             self.assertEqual(conv2d_out.relu(), cudnn_out, atol=4e-3, rtol=0.006)
         else:
             self.assertEqual(conv2d_out.relu(), cudnn_out)
@@ -3958,7 +3953,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
         )

         self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
-        if tf32_is_not_fp32() and dtype == torch.float:
+        if torch.cuda.is_tf32_supported() and dtype == torch.float:
             self.assertEqual(
                 F.relu(conv2d_out + alpha * z), cudnn_out, atol=2e-3, rtol=0.006
             )
@@ -51,11 +51,11 @@ import torch.testing._internal.hypothesis_utils as hu
 from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck, \
     GRADCHECK_NONDET_TOL
 from torch.testing._internal.common_utils import dtype2prec_DONTUSE
-from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
+from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_off, tf32_on
 from torch.types import _TensorOrTensors
 from torch.testing._internal.common_mkldnn import bf32_on_and_off

-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()

 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -7326,7 +7326,7 @@ def add_test(test, decorator=None):
             kwargs['extra_args'] = test.extra_args

         if 'dtype' in get_function_arglist(test.test_cuda):
-            if tf32_is_not_fp32() and test.with_tf32:
+            if torch.cuda.is_tf32_supported() and test.with_tf32:

                 def with_tf32_off(self, test=test, kwargs=kwargs):
                     with tf32_off():
@@ -7369,7 +7369,7 @@ def add_test(test, decorator=None):
                     with tf32_off():
                         test.test_cuda(self, **kwargs)

-            if tf32_is_not_fp32() and test.with_tf32:
+            if torch.cuda.is_tf32_supported() and test.with_tf32:
                 add(cuda_test_name + '_fp32', with_tf32_off)

                 def with_tf32_on(self, test=test, kwargs=kwargs):
@@ -56,7 +56,7 @@ from torch.testing._internal.common_device_type import (
 import torch.backends.quantized
 import torch.testing._internal.data
 from torch.testing._internal.common_cuda import (
-    tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN, TEST_MULTIGPU,
+    tf32_on_and_off, TEST_CUDNN, TEST_MULTIGPU,
     _create_scaling_case, _create_scaling_models_optimizers)
 from torch.testing._internal.common_mkldnn import bf32_on_and_off
 from torch.testing._internal.common_dtype import (
@@ -79,7 +79,7 @@ assert torch.get_default_dtype() is torch.float32
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests

-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()

 @contextlib.contextmanager
 def torch_vital_set(value):
@@ -11,7 +11,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch._C._dynamo.guards import assert_size_stride
 from torch.testing import make_tensor
-from torch.testing._internal.common_cuda import tf32_is_not_fp32
 from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,
@@ -31,7 +30,7 @@ from torch.testing._internal.common_utils import (
 )


-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()
 if TEST_SCIPY:
     import scipy.ndimage
     import scipy.signal
@@ -1063,6 +1063,14 @@ class TritonTemplate(KernelTemplate):
         """
         assert self.template, "requires jinja2"
         defines = StringIO()
+
+        # HACK: Triton currently breaks if TF32 floats are requested, but the CUDA
+        # capability doesn't support them. This is a bug in Triton, but for now we'll
+        # patch around it here. See https://github.com/triton-lang/triton/issues/3011
+        # for one example issue with this problem.
+        if not torch.cuda.is_tf32_supported():
+            kwargs["ALLOW_TF32"] = "False"
+
         for name, val in kwargs.items():
             defines.write(f"{name} : tl.constexpr = {val}\n")
         defines = defines.getvalue()
@@ -168,6 +168,18 @@ def _check_bf16_tensor_supported(device: _device_t):
         return False


+def is_tf32_supported() -> bool:
+    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype tf32."""
+    # Check for ROCm. If true, return false, since PyTorch does not currently support
+    # tf32 on ROCm.
+    if torch.version.hip:
+        return False
+
+    # Otherwise, tf32 is supported on CUDA platforms that natively (i.e. no emulation)
+    # support bfloat16.
+    return is_bf16_supported(including_emulation=False)
+
+
 def _sleep(cycles):
     torch._C._cuda_sleep(cycles)
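For context, the "natively supports bfloat16" condition used above amounts to an Ampere-or-newer check on CUDA hardware. The sketch below is a rough, hand-rolled equivalent for illustration only; the actual helper added in this PR defers to `is_bf16_supported(including_emulation=False)` rather than inspecting the compute capability directly.

```python
import torch

def tf32_capable_sketch() -> bool:
    # TF32 is a CUDA-only feature; ROCm builds report a non-None torch.version.hip.
    if not torch.cuda.is_available() or torch.version.hip is not None:
        return False
    # TF32 tensor cores first appeared with Ampere (compute capability 8.0).
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8
```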
@@ -1711,6 +1723,7 @@ __all__ = [
     "is_bf16_supported",
     "is_current_stream_capturing",
     "is_initialized",
+    "is_tf32_supported",
     "jiterator",
     "list_gpu_processes",
     "make_graphed_callables",
@@ -120,19 +120,6 @@ def initialize_cuda_context_rng():
         __cuda_ctx_rng_initialized = True


-# Test whether hardware TF32 math mode enabled. It is enabled only on:
-# - CUDA >= 11
-# - arch >= Ampere
-def tf32_is_not_fp32():
-    if not torch.cuda.is_available() or torch.version.cuda is None:
-        return False
-    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
-        return False
-    if int(torch.version.cuda.split('.')[0]) < 11:
-        return False
-    return True
-
-
 @contextlib.contextmanager
 def tf32_off():
     old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
@@ -220,7 +207,7 @@ def tf32_on_and_off(tf32_precision=1e-5):
         def wrapped(*args, **kwargs):
             for k, v in zip(arg_names, args):
                 kwargs[k] = v
-            cond = tf32_is_not_fp32()
+            cond = torch.cuda.is_tf32_supported()
             if 'device' in kwargs:
                 cond = cond and (torch.device(kwargs['device']).type == 'cuda')
             if 'dtype' in kwargs: