Files
pytorch/torch/testing/_internal/common_cuda.py
Xinya Zhang e769026bcb [ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (#162998)
A few UT failures are caused by `HIPBLASLT_ALLOW_TF32`

Fixes #157094
Fixes #157093
Fixes #157092
Fixes #157091
Fixes #157064
Fixes #157063
Fixes #157062
Fixes #157061
Fixes #157042
Fixes #157041
Fixes #157039
Fixes #157004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162998
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-09-18 13:53:48 +00:00

371 lines
16 KiB
Python

# mypy: ignore-errors
r"""This file is allowed to initialize CUDA context when imported."""
import functools
import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS, IS_MACOS
import inspect
import contextlib
import os
import unittest
CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
if TEST_WITH_ROCM:
TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
else:
TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)
ROCM_VERSION = LazyVal(lambda : tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip else (0, 0))
SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0))
SM120OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (12, 0))
IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
and torch.cuda.get_device_capability()[1] > 0)
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR))
IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))
IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0))
IS_SM100 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0))
def evaluate_gfx_arch_within(arch_list):
if not torch.cuda.is_available():
return False
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
# gcnArchName can be complicated strings like gfx90a:sramecc+:xnack-
# Hence the matching should be done reversely
return any(arch in effective_arch for arch in arch_list)
def CDNA3OrLater():
return evaluate_gfx_arch_within(["gfx940", "gfx941", "gfx942", "gfx950"])
def CDNA2OrLater():
return evaluate_gfx_arch_within(["gfx90a", "gfx942"])
def evaluate_platform_supports_flash_attention():
if TEST_WITH_ROCM:
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
return evaluate_gfx_arch_within(arch_list)
if TEST_CUDA:
return not IS_WINDOWS and SM80OrLater
return False
def evaluate_platform_supports_efficient_attention():
if TEST_WITH_ROCM:
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
return evaluate_gfx_arch_within(arch_list)
if TEST_CUDA:
return True
return False
def evaluate_platform_supports_cudnn_attention():
return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or
PLATFORM_SUPPORTS_CUDNN_ATTENTION or
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)
PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
def evaluate_platform_supports_fp8():
if torch.cuda.is_available():
if torch.version.hip:
archs = ['gfx94']
if ROCM_VERSION >= (6, 3):
archs.extend(['gfx120'])
if ROCM_VERSION >= (6, 5):
archs.append('gfx95')
for arch in archs:
if arch in torch.cuda.get_device_properties(0).gcnArchName:
return True
else:
return SM90OrLater or torch.cuda.get_device_capability() == (8, 9)
return False
def evaluate_platform_supports_fp8_grouped_gemm():
if torch.cuda.is_available():
if torch.version.hip:
if "USE_FBGEMM_GENAI" not in torch.__config__.show():
return False
archs = ['gfx942']
for arch in archs:
if arch in torch.cuda.get_device_properties(0).gcnArchName:
return True
else:
return SM90OrLater and not SM100OrLater
return False
def evaluate_platform_supports_mx_gemm():
if torch.cuda.is_available():
if torch.version.hip:
if ROCM_VERSION >= (7, 0):
return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName
else:
return SM100OrLater
return False
def evaluate_platform_supports_mxfp8_grouped_gemm():
if torch.cuda.is_available() and not torch.version.hip:
built_with_fbgemm_genai = "USE_FBGEMM_GENAI" in torch.__config__.show()
return built_with_fbgemm_genai and IS_SM100
return False
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm())
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm())
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm())
if TEST_NUMBA:
try:
import numba.cuda
TEST_NUMBA_CUDA = numba.cuda.is_available()
except Exception:
TEST_NUMBA_CUDA = False
TEST_NUMBA = False
else:
TEST_NUMBA_CUDA = False
# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False
# after this call, CUDA context and RNG must have been initialized on each GPU
def initialize_cuda_context_rng():
global __cuda_ctx_rng_initialized
assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
if not __cuda_ctx_rng_initialized:
# initialize cuda context and rng for memory tests
for i in range(torch.cuda.device_count()):
torch.randn(1, device=f"cuda:{i}")
__cuda_ctx_rng_initialized = True
@contextlib.contextmanager
def tf32_off():
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
try:
torch.backends.cuda.matmul.allow_tf32 = False
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
old_precision = self.precision
try:
torch.backends.cuda.matmul.allow_tf32 = True
self.precision = tf32_precision
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
self.precision = old_precision
@contextlib.contextmanager
def tf32_enabled():
"""
Context manager to temporarily enable TF32 for CUDA operations.
Restores the previous TF32 state after exiting the context.
"""
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
try:
torch.backends.cuda.matmul.allow_tf32 = True
with torch.backends.cudnn.flags(
enabled=None, benchmark=None, deterministic=None, allow_tf32=True
):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
# This is a wrapper that wraps a test to run this test twice, one with
# allow_tf32=True, another with allow_tf32=False. When running with
# allow_tf32=True, it will use reduced precision as specified by the
# argument. For example:
# @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
# @tf32_on_and_off(0.005)
# def test_matmul(self, device, dtype):
# a = ...; b = ...;
# c = torch.matmul(a, b)
# self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA
# on a CUDA >= 11 build on an >=Ampere architecture, the matmul will be running at
# TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced
# precision to check values.
#
# This decorator can be used for function with or without device/dtype, such as
# @tf32_on_and_off(0.005)
# def test_my_op(self)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device, dtype)
# @tf32_on_and_off(0.005)
# def test_my_op(self, dtype)
# if neither device nor dtype is specified, it will check if the system has ampere device
# if device is specified, it will check if device is cuda
# if dtype is specified, it will check if dtype is float32 or complex64
# tf32 and fp32 are different only when all the three checks pass
def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True):
def with_tf32_disabled(self, function_call):
with tf32_off():
function_call()
def with_tf32_enabled(self, function_call):
with tf32_on(self, tf32_precision):
function_call()
def wrapper(f):
params = inspect.signature(f).parameters
arg_names = tuple(params.keys())
@functools.wraps(f)
def wrapped(*args, **kwargs):
kwargs.update(zip(arg_names, args))
cond = torch.cuda.is_tf32_supported() and only_if
if 'device' in kwargs:
cond = cond and (torch.device(kwargs['device']).type == 'cuda')
if 'dtype' in kwargs:
cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
if cond:
with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
else:
f(**kwargs)
return wrapped
return wrapper
# This is a wrapper that wraps a test to run it with TF32 turned off.
# This wrapper is designed to be used when a test uses matmul or convolutions
# but the purpose of that test is not testing matmul or convolutions.
# Disabling TF32 will enforce torch.float tensors to be always computed
# at full precision.
def with_tf32_off(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
with tf32_off():
return f(*args, **kwargs)
return wrapped
def _get_magma_version():
if 'Magma' not in torch.__config__.show():
return (0, 0)
position = torch.__config__.show().find('Magma ')
version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
return tuple(int(x) for x in version_str.split("."))
def _get_torch_cuda_version():
if torch.version.cuda is None:
return (0, 0)
cuda_version = str(torch.version.cuda)
return tuple(int(x) for x in cuda_version.split("."))
def _get_torch_rocm_version():
if not TEST_WITH_ROCM or torch.version.hip is None:
return (0, 0)
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha
return tuple(int(x) for x in rocm_version.split("."))
def _check_cusparse_generic_available():
return not TEST_WITH_ROCM
def _check_hipsparse_generic_available():
if not TEST_WITH_ROCM:
return False
if not torch.version.hip:
return False
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))
TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()
# Shared by test_torch.py and test_multigpu.py
def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
# Create a module+optimizer that will use scaling, and a control module+optimizer
# that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
with torch.no_grad():
for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
s.copy_(c)
kwargs = {"lr": 1.0}
if optimizer_kwargs is not None:
kwargs.update(optimizer_kwargs)
opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)
return mod_control, mod_scaling, opt_control, opt_scaling
# Shared by test_torch.py, test_cuda.py and test_multigpu.py
def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]
loss_fn = torch.nn.MSELoss().to(device)
skip_iter = 2
return _create_scaling_models_optimizers(
device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
) + (data, loss_fn, skip_iter)
def xfailIfSM89(func):
return func if not IS_SM89 else unittest.expectedFailure(func)
def xfailIfSM100OrLater(func):
return func if not SM100OrLater else unittest.expectedFailure(func)
def xfailIfSM120OrLater(func):
return func if not SM120OrLater else unittest.expectedFailure(func)
def xfailIfDistributedNotSupported(func):
return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func)
# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
assert not torch.cuda.is_initialized()