[ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (#162998)

A few unit-test failures are caused by the `HIPBLASLT_ALLOW_TF32` environment variable, so this change removes it from the codebase entirely.

Fixes #157094
Fixes #157093
Fixes #157092
Fixes #157091
Fixes #157064
Fixes #157063
Fixes #157062
Fixes #157061
Fixes #157042
Fixes #157041
Fixes #157039
Fixes #157004
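
With the environment variable gone, TF32 for matmul on ROCm is driven by the same PyTorch-level switches as on CUDA. A minimal sketch of those switches, shown here only for reference:

```python
import torch

# Enabling TF32 for matmul is equivalent to the "high" float32 matmul precision.
torch.backends.cuda.matmul.allow_tf32 = True
print(torch.get_float32_matmul_precision())  # "high"

# Requesting full FP32 precision turns TF32 back off.
torch.set_float32_matmul_precision("highest")
print(torch.backends.cuda.matmul.allow_tf32)  # False
```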

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162998
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
authored by Xinya Zhang, 2025-09-18 13:53:48 +00:00
committed by PyTorch MergeBot
parent 14f8d86136
commit e769026bcb
10 changed files with 48 additions and 196 deletions

View File

@@ -180,7 +180,7 @@ void Context::setUserEnabledNNPACK(bool e) {
}
bool Context::allowTF32CuDNN(const std::string& op) const {
if (op.size() == 0){
if (op.empty()){
bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
TORCH_CHECK(
@@ -281,9 +281,6 @@ bool Context::userEnabledOverrideableSDP() const {
static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
#ifdef USE_ROCM
static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
#endif
bool Context::checkCuBLASConfigDeterministic() {
// If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -343,12 +340,6 @@ void Context::setImmediateMiopen(bool b) {
}
bool Context::allowTF32CuBLAS() const {
#ifdef USE_ROCM
const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
if (allow_tf32 != true) {
return false;
}
#endif
bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
TORCH_CHECK(
@@ -362,14 +353,6 @@ bool Context::allowTF32CuBLAS() const {
}
void Context::setAllowTF32CuBLAS(bool b) {
#ifdef USE_ROCM
const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
if (allow_tf32 != true) {
C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
<< "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
return;
}
#endif
float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
}
@@ -443,7 +426,7 @@ void Context::setFloat32Precision(const std::string& backend, const std::string&
std::string msg;
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
for (auto p : iterp->second) {
for (const auto& p : iterp->second) {
msg += p;
msg += " ";
}

View File

@@ -1,6 +1,5 @@
# Owner(s): ["module: dynamo"]
import contextlib
import os
import torch
import torch.fx
@@ -196,21 +195,6 @@ class GraphRegionTrackerTests(TestCase):
)
def test_mismatched_global_state(self):
@contextlib.contextmanager
def _hip_allow_tf32():
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
try:
yield
finally:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]
def inner_fn(x, y):
x1 = x * 1
y1 = y + 1
@@ -251,31 +235,29 @@ class GraphRegionTrackerTests(TestCase):
def reset_default_dtype():
torch.set_default_dtype(old_dtype)
tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
with tf32_ctx():
for ctx in [
lambda: torch.set_grad_enabled(False),
torch.autograd.grad_mode.inference_mode,
lambda: torch.autograd.graph.disable_saved_tensors_hooks(
"This is not supported"
),
# lambda: torch.set_num_threads(2), : Unsupported
(set_default_dtype_bfloat16, reset_default_dtype),
(
lambda: torch.use_deterministic_algorithms(True),
lambda: torch.use_deterministic_algorithms(False),
),
# (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
# lambda: torch.use_deterministic_algorithms(False)), : Unsupported
create_toggle_fns("allow_bf16_reduced_precision_reduction"),
create_toggle_fns("allow_fp16_reduced_precision_reduction"),
create_toggle_fns("allow_tf32"),
]:
self.assertExpectedInline(
self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
"""[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
for ctx in [
lambda: torch.set_grad_enabled(False),
torch.autograd.grad_mode.inference_mode,
lambda: torch.autograd.graph.disable_saved_tensors_hooks(
"This is not supported"
),
# lambda: torch.set_num_threads(2), : Unsupported
(set_default_dtype_bfloat16, reset_default_dtype),
(
lambda: torch.use_deterministic_algorithms(True),
lambda: torch.use_deterministic_algorithms(False),
),
# (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
# lambda: torch.use_deterministic_algorithms(False)), : Unsupported
create_toggle_fns("allow_bf16_reduced_precision_reduction"),
create_toggle_fns("allow_fp16_reduced_precision_reduction"),
create_toggle_fns("allow_tf32"),
]:
self.assertExpectedInline(
self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
"""[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
)
)
def test_mutation_tracking_simple(self):
def fn(x, y, z):

View File

@@ -8478,43 +8478,24 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
def fn(x):
return x + 1
import contextlib
@contextlib.contextmanager
def _hip_allow_tf32():
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
try:
yield
finally:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]
tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
with tf32_ctx():
initial_state = read_state()
y = torch.randn(10)
try:
for round in range(3):
for i in range(len(initial_state)):
new_state = [False] * len(initial_state)
new_state[i] = True
write_state(new_state)
assert read_state() == new_state
last_state.clear()
fn(y)
assert last_state == new_state
if round == 0:
assert cnt == i + 1
else:
assert cnt == len(initial_state)
finally:
write_state(initial_state)
initial_state = read_state()
y = torch.randn(10)
try:
for round in range(3):
for i in range(len(initial_state)):
new_state = [False] * len(initial_state)
new_state[i] = True
write_state(new_state)
assert read_state() == new_state
last_state.clear()
fn(y)
assert last_state == new_state
if round == 0:
assert cnt == i + 1
else:
assert cnt == len(initial_state)
finally:
write_state(initial_state)
def test_grad_state_mutated(self):
prior = torch.is_grad_enabled()

View File

@@ -43,9 +43,6 @@ if IS_WINDOWS and IS_CI:
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
if torch.version.hip:
torch.set_float32_matmul_precision("highest")
else:

View File

@@ -109,9 +109,6 @@ class TestCaseBase(TestCase):
if HAS_GPU:
cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision()
cls.prior_default_device = torch.get_default_device()
# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
if torch.version.hip:
torch.set_float32_matmul_precision("highest")
else:

View File

@@ -759,53 +759,7 @@ print(t.is_pinned())
torch._C._cuda_clearCublasWorkspaces()
@contextlib.contextmanager
def _hip_allow_tf32(self):
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
try:
yield
finally:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]
@unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
def test_hipblaslt_allow_tf32(self):
tf32_ctx = self._hip_allow_tf32
with tf32_ctx():
os.environ["HIPBLASLT_ALLOW_TF32"] = "0"
# Save original value of allow_tf32
orig = torch.backends.cuda.matmul.allow_tf32
# If allow_tf32 variable is declared as static in aten/src/ATen/Context.cpp
# then matmul.allow_tf32 will return False after this point even if
# HIP_BLASLT_ALLOW_TF32 is set to 1 and matmul.allow_tf32 is changed.
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
# Toggle torch.backends.cuda.matmul.allow_tf32 couple of times.
torch.backends.cuda.matmul.allow_tf32 = not orig
test1 = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = orig
test2 = torch.backends.cuda.matmul.allow_tf32
self.assertNotEqual(test1, test2)
# Restore original value of allow_tf32
torch.backends.cuda.matmul.allow_tf32 = orig
def test_cublas_allow_tf32_get_set(self):
"""
We only turn on TF32 for MI300 with a special env var. This is because TF32
is only available in MI300+ and is in experimental mode (hipblaslt support
is current WIP)
"""
tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
with tf32_ctx():
self._test_cublas_allow_tf32_get_set_inner()
def _test_cublas_allow_tf32_get_set_inner(self):
skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
)
@@ -820,12 +774,6 @@ print(t.is_pinned())
torch.backends.cuda.matmul.allow_tf32 = orig
def test_float32_matmul_precision_get_set(self):
tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
with tf32_ctx():
self._test_float32_matmul_precision_get_set_inner()
def _test_float32_matmul_precision_get_set_inner(self):
orig = torch.get_float32_matmul_precision()
skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]

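For context (not part of the diff): with the env-var gate removed, the round trip that the deleted `test_hipblaslt_allow_tf32` exercised reduces to toggling the backend flag directly. A minimal sketch, assuming a CUDA or ROCm build:

```python
import torch

# The flag now takes effect immediately on ROCm; no HIPBLASLT_ALLOW_TF32 needed.
orig = torch.backends.cuda.matmul.allow_tf32
try:
    torch.backends.cuda.matmul.allow_tf32 = not orig
    assert torch.backends.cuda.matmul.allow_tf32 != orig
finally:
    torch.backends.cuda.matmul.allow_tf32 = orig  # restore the original setting
```
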
View File

@@ -109,22 +109,6 @@ def get_tunableop_untuned_filename():
return untuned_filename
class TestLinalg(TestCase):
@contextlib.contextmanager
def _hip_allow_tf32(self):
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+. Environment variable will be removed in the future.
import os
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
try:
yield
finally:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]
def setUp(self):
super().setUp()
torch.backends.cuda.matmul.allow_tf32 = False
@@ -5542,13 +5526,8 @@ class TestLinalg(TestCase):
@runOnRocmArch(MI300_ARCH)
@dtypes(torch.float)
def test_tf32_tunableop(self, device, dtype):
# Test TunableOp with TF32. Supported by hipblasLT on MI300+.
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+. Eventually this flag will go away.
tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
try:
with self._tunableop_ctx(), tf32_ctx():
with self._tunableop_ctx():
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.tunable.set_rotating_buffer_size(0)
@@ -5611,13 +5590,8 @@ class TestLinalg(TestCase):
# This test is the offline version of test_tf32_tunableop
import os
# Test TunableOp with TF32. Supported by hipblasLT on MI300+.
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+. Eventually this flag will go away.
tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
try:
with self._tunableop_ctx(), tf32_ctx():
with self._tunableop_ctx():
torch.backends.cuda.matmul.allow_tf32 = True
ordinal = torch.cuda.current_device()
torch.cuda.tunable.set_rotating_buffer_size(0)

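For reference, a minimal sketch of the simplified TunableOp-with-TF32 flow the updated test follows (API names taken from the hunk above; assumes a ROCm build with hipBLASLt, or CUDA):

```python
import torch

torch.cuda.tunable.enable(True)                 # turn on TunableOp
torch.backends.cuda.matmul.allow_tf32 = True    # TF32 no longer needs an env var
torch.cuda.tunable.set_rotating_buffer_size(0)

a = torch.randn(256, 256, device="cuda")
b = torch.randn(256, 256, device="cuda")
c = a @ b                                       # this GEMM is tuned and may run in TF32
```
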
View File

@@ -51,7 +51,6 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
tf32_on_and_off,
tf32_enabled,
ROCM_VERSION,
)
if TEST_FAIRSEQ:
@@ -340,7 +339,7 @@ class TestTransformers(NNTestCase):
l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
@tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
@tf32_on_and_off(0.001)
@parametrize("attn_mask_dim", [2, 3, None])
@parametrize("key_padding_mask_dim", [2, None])
@parametrize("mask_dtype", [torch.bool, torch.float32])
@@ -524,7 +523,7 @@ class TestTransformers(NNTestCase):
slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
self.assertEqual(fastpath_output_expanded, slowpath_output)
@tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
@tf32_on_and_off(0.001)
@parametrize("with_no_grad", [True, False])
@parametrize("training", [True, False])
@parametrize("enable_nested_tensor", [False])
@@ -1110,7 +1109,7 @@ class TestTransformers(NNTestCase):
return_all_hiddens=False,
)[0]
@tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
@tf32_on_and_off(0.003)
@parametrize("input_dim,attn_mask_dim,is_causal",
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],

View File

@@ -591,7 +591,6 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
transA = layout[1] == "T"
dtype = dtype_dict.get(data_type)
if data_type == "tf32":
# User must still set HIPBLASLT_ALLOW_TF32=1
torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.backends.cuda.matmul.allow_tf32 = False

View File

@@ -181,9 +181,6 @@ def tf32_off():
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
if torch.version.hip:
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
old_precision = self.precision
try:
@@ -192,11 +189,6 @@ def tf32_on(self, tf32_precision=1e-5):
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
yield
finally:
if torch.version.hip:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
self.precision = old_precision
@@ -246,7 +238,7 @@ def tf32_enabled():
# if device is specified, it will check if device is cuda
# if dtype is specified, it will check if dtype is float32 or complex64
# tf32 and fp32 are different only when all the three checks pass
def tf32_on_and_off(tf32_precision=1e-5, only_if=True):
def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True):
def with_tf32_disabled(self, function_call):
with tf32_off():
function_call()
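
A hypothetical usage sketch of the decorator after this change (call-site names mirror test_transformers.py above; `only_if` must now be passed by keyword):

```python
from torch.testing._internal.common_cuda import tf32_on_and_off
from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase

class MyMatmulTests(TestCase):
    @tf32_on_and_off(0.001)  # runs the body once with TF32 off, then once with TF32 on
    def test_attention_numerics(self, device="cuda", dtype=None):
        ...

    @tf32_on_and_off(0.003, only_if=not TEST_WITH_ROCM)  # keyword-only; can skip the TF32 pass
    def test_other_numerics(self, device="cuda", dtype=None):
        ...
```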