From e769026bcbbf6a2fc4b9be69e866c644e07b5df7 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Thu, 18 Sep 2025 13:53:48 +0000
Subject: [PATCH] [ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (#162998)

A few UT failures are caused by `HIPBLASLT_ALLOW_TF32`

Fixes #157094
Fixes #157093
Fixes #157092
Fixes #157091
Fixes #157064
Fixes #157063
Fixes #157062
Fixes #157061
Fixes #157042
Fixes #157041
Fixes #157039
Fixes #157004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162998
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily
---
 aten/src/ATen/Context.cpp                | 21 +-------
 test/dynamo/test_graph_region_tracker.py | 62 +++++++++---------
 test/dynamo/test_misc.py                 | 55 +++++++--------
 test/inductor/test_flex_decoding.py      |  3 --
 test/inductor/test_padding.py            |  3 --
 test/test_cuda.py                        | 52 --------
 test/test_linalg.py                      | 30 +-----
 test/test_transformers.py                |  7 ++-
 torch/cuda/tunable.py                    |  1 -
 torch/testing/_internal/common_cuda.py   | 10 +---
 10 files changed, 48 insertions(+), 196 deletions(-)

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 4d48084b0ab8..7a8d02be530e 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -180,7 +180,7 @@ void Context::setUserEnabledNNPACK(bool e) {
 }
 
 bool Context::allowTF32CuDNN(const std::string& op) const {
-  if (op.size() == 0){
+  if (op.empty()){
     bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
     bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
     TORCH_CHECK(
@@ -281,9 +281,6 @@ bool Context::userEnabledOverrideableSDP() const {
 
 static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
 static constexpr const std::array cublas_deterministic_configs = {":4096:8", ":16:8"};
-#ifdef USE_ROCM
-static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
-#endif
 
 bool Context::checkCuBLASConfigDeterministic() {
   // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -343,12 +340,6 @@ void Context::setImmediateMiopen(bool b) {
 }
 
 bool Context::allowTF32CuBLAS() const {
-#ifdef USE_ROCM
-  const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
-  if (allow_tf32 != true) {
-    return false;
-  }
-#endif
   bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
   bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
   TORCH_CHECK(
@@ -362,14 +353,6 @@ bool Context::allowTF32CuBLAS() const {
 }
 
 void Context::setAllowTF32CuBLAS(bool b) {
-#ifdef USE_ROCM
-  const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
-  if (allow_tf32 != true) {
-    C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
-                              << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
-    return;
-  }
-#endif
   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
   setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
 }
@@ -443,7 +426,7 @@ void Context::setFloat32Precision(const std::string& backend, const std::string&
   std::string msg;
   auto iterp = _fp32_precisions.find(backend);
   TORCH_CHECK(iterp != _fp32_precisions.end());
-  for (auto p : iterp->second) {
+  for (const auto& p : iterp->second) {
     msg += p;
     msg += " ";
   }
diff --git a/test/dynamo/test_graph_region_tracker.py b/test/dynamo/test_graph_region_tracker.py
index e930ff787a9a..ce456596fd55 100644
--- a/test/dynamo/test_graph_region_tracker.py
+++ b/test/dynamo/test_graph_region_tracker.py
@@ -1,6 +1,5 @@
 # Owner(s): ["module: dynamo"]
 import contextlib
-import os
 
 import torch
 import torch.fx
@@ -196,21 +195,6 @@ class GraphRegionTrackerTests(TestCase):
         )
 
     def test_mismatched_global_state(self):
-        @contextlib.contextmanager
-        def _hip_allow_tf32():
-            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-            # and only for MI300+
-            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-            try:
-                yield
-            finally:
-                if hip_allow_tf32 is not None:
-                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-                else:
-                    del os.environ["HIPBLASLT_ALLOW_TF32"]
-
         def inner_fn(x, y):
             x1 = x * 1
             y1 = y + 1
@@ -251,31 +235,29 @@ class GraphRegionTrackerTests(TestCase):
         def reset_default_dtype():
             torch.set_default_dtype(old_dtype)
 
-        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-        with tf32_ctx():
-            for ctx in [
-                lambda: torch.set_grad_enabled(False),
-                torch.autograd.grad_mode.inference_mode,
-                lambda: torch.autograd.graph.disable_saved_tensors_hooks(
-                    "This is not supported"
-                ),
-                # lambda: torch.set_num_threads(2), : Unsupported
-                (set_default_dtype_bfloat16, reset_default_dtype),
-                (
-                    lambda: torch.use_deterministic_algorithms(True),
-                    lambda: torch.use_deterministic_algorithms(False),
-                ),
-                # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
-                # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
-                create_toggle_fns("allow_bf16_reduced_precision_reduction"),
-                create_toggle_fns("allow_fp16_reduced_precision_reduction"),
-                create_toggle_fns("allow_tf32"),
-            ]:
-                self.assertExpectedInline(
-                    self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
-                    """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
+        for ctx in [
+            lambda: torch.set_grad_enabled(False),
+            torch.autograd.grad_mode.inference_mode,
+            lambda: torch.autograd.graph.disable_saved_tensors_hooks(
+                "This is not supported"
+            ),
+            # lambda: torch.set_num_threads(2), : Unsupported
+            (set_default_dtype_bfloat16, reset_default_dtype),
+            (
+                lambda: torch.use_deterministic_algorithms(True),
+                lambda: torch.use_deterministic_algorithms(False),
+            ),
+            # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
+            # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
+            create_toggle_fns("allow_bf16_reduced_precision_reduction"),
+            create_toggle_fns("allow_fp16_reduced_precision_reduction"),
+            create_toggle_fns("allow_tf32"),
+        ]:
+            self.assertExpectedInline(
+                self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
+                """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
 [['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
-                )
+            )
 
     def test_mutation_tracking_simple(self):
         def fn(x, y, z):
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 85831321f09a..e09191261e3d 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -8478,43 +8478,24 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         def fn(x):
             return x + 1
 
-        import contextlib
-
-        @contextlib.contextmanager
-        def _hip_allow_tf32():
-            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-            # and only for MI300+
-            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-            try:
-                yield
-            finally:
-                if hip_allow_tf32 is not None:
-                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-                else:
-                    del os.environ["HIPBLASLT_ALLOW_TF32"]
-
-        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-        with tf32_ctx():
-            initial_state = read_state()
-            y = torch.randn(10)
-            try:
-                for round in range(3):
-                    for i in range(len(initial_state)):
-                        new_state = [False] * len(initial_state)
-                        new_state[i] = True
-                        write_state(new_state)
-                        assert read_state() == new_state
-                        last_state.clear()
-                        fn(y)
-                        assert last_state == new_state
-                        if round == 0:
-                            assert cnt == i + 1
-                        else:
-                            assert cnt == len(initial_state)
-            finally:
-                write_state(initial_state)
+        initial_state = read_state()
+        y = torch.randn(10)
+        try:
+            for round in range(3):
+                for i in range(len(initial_state)):
+                    new_state = [False] * len(initial_state)
+                    new_state[i] = True
+                    write_state(new_state)
+                    assert read_state() == new_state
+                    last_state.clear()
+                    fn(y)
+                    assert last_state == new_state
+                    if round == 0:
+                        assert cnt == i + 1
+                    else:
+                        assert cnt == len(initial_state)
+        finally:
+            write_state(initial_state)
 
     def test_grad_state_mutated(self):
         prior = torch.is_grad_enabled()
diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index 120d8d36b439..849aefff8a96 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -43,9 +43,6 @@ if IS_WINDOWS and IS_CI:
 
 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
 
-# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
-# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
-# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
 if torch.version.hip:
     torch.set_float32_matmul_precision("highest")
 else:
diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py
index 9ef3a18e2423..c67bde87a369 100644
--- a/test/inductor/test_padding.py
+++ b/test/inductor/test_padding.py
@@ -109,9 +109,6 @@ class TestCaseBase(TestCase):
         if HAS_GPU:
             cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision()
             cls.prior_default_device = torch.get_default_device()
-            # In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
-            # In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
-            # logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
             if torch.version.hip:
                 torch.set_float32_matmul_precision("highest")
             else:
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 7bd310042862..b809fc521600 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -759,53 +759,7 @@ print(t.is_pinned())
 
         torch._C._cuda_clearCublasWorkspaces()
 
-    @contextlib.contextmanager
-    def _hip_allow_tf32(self):
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+
-        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-        try:
-            yield
-        finally:
-            if hip_allow_tf32 is not None:
-                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-            else:
-                del os.environ["HIPBLASLT_ALLOW_TF32"]
-
-    @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
-    def test_hipblaslt_allow_tf32(self):
-        tf32_ctx = self._hip_allow_tf32
-        with tf32_ctx():
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "0"
-            # Save original value of allow_tf32
-            orig = torch.backends.cuda.matmul.allow_tf32
-            # If allow_tf32 variable is declared as static in aten/src/ATen/Context.cpp
-            # then matmul.allow_tf32 will return False after this point even if
-            # HIP_BLASLT_ALLOW_TF32 is set to 1 and matmul.allow_tf32 is changed.
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-            # Toggle torch.backends.cuda.matmul.allow_tf32 couple of times.
-            torch.backends.cuda.matmul.allow_tf32 = not orig
-            test1 = torch.backends.cuda.matmul.allow_tf32
-            torch.backends.cuda.matmul.allow_tf32 = orig
-            test2 = torch.backends.cuda.matmul.allow_tf32
-            self.assertNotEqual(test1, test2)
-            # Restore original value of allow_tf32
-            torch.backends.cuda.matmul.allow_tf32 = orig
-
     def test_cublas_allow_tf32_get_set(self):
-        """
-        We only turn on TF32 for MI300 with a special env var. This is because TF32
-        is only available in MI300+ and is in experimental mode (hipblaslt support
-        is current WIP)
-        """
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
-        with tf32_ctx():
-            self._test_cublas_allow_tf32_get_set_inner()
-
-    def _test_cublas_allow_tf32_get_set_inner(self):
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
         )
@@ -820,12 +774,6 @@ print(t.is_pinned())
         torch.backends.cuda.matmul.allow_tf32 = orig
 
     def test_float32_matmul_precision_get_set(self):
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
-        with tf32_ctx():
-            self._test_float32_matmul_precision_get_set_inner()
-
-    def _test_float32_matmul_precision_get_set_inner(self):
         orig = torch.get_float32_matmul_precision()
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
diff --git a/test/test_linalg.py b/test/test_linalg.py
index ffae8ac18da2..4f8780dfc30a 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -109,22 +109,6 @@ def get_tunableop_untuned_filename():
     return untuned_filename
 
 class TestLinalg(TestCase):
-    @contextlib.contextmanager
-    def _hip_allow_tf32(self):
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Environment variable will be removed in the future.
-        import os
-        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-        try:
-            yield
-        finally:
-            if hip_allow_tf32 is not None:
-                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-            else:
-                del os.environ["HIPBLASLT_ALLOW_TF32"]
-
     def setUp(self):
         super().setUp()
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -5542,13 +5526,8 @@ class TestLinalg(TestCase):
     @runOnRocmArch(MI300_ARCH)
     @dtypes(torch.float)
     def test_tf32_tunableop(self, device, dtype):
-        # Test TunableOp with TF32. Supported by hipblasLT on MI300+.
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Eventually this flag will go away.
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
         try:
-            with self._tunableop_ctx(), tf32_ctx():
+            with self._tunableop_ctx():
                 torch.backends.cuda.matmul.allow_tf32 = True
 
                 torch.cuda.tunable.set_rotating_buffer_size(0)
@@ -5611,13 +5590,8 @@ class TestLinalg(TestCase):
         # This test is the offline version of test_tf32_tunableop
         import os
 
-        # Test TunableOp with TF32. Supported by hipblasLT on MI300+.
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Eventually this flag will go away.
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
         try:
-            with self._tunableop_ctx(), tf32_ctx():
+            with self._tunableop_ctx():
                 torch.backends.cuda.matmul.allow_tf32 = True
                 ordinal = torch.cuda.current_device()
                 torch.cuda.tunable.set_rotating_buffer_size(0)
diff --git a/test/test_transformers.py b/test/test_transformers.py
index c58fe05d37be..7c2060034710 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -51,7 +51,6 @@ from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     tf32_on_and_off,
     tf32_enabled,
-    ROCM_VERSION,
 )
 
 if TEST_FAIRSEQ:
@@ -340,7 +339,7 @@ class TestTransformers(NNTestCase):
             l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
             self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
 
-    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.001)
     @parametrize("attn_mask_dim", [2, 3, None])
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
@@ -524,7 +523,7 @@ class TestTransformers(NNTestCase):
             slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
             self.assertEqual(fastpath_output_expanded, slowpath_output)
 
-    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.001)
     @parametrize("with_no_grad", [True, False])
     @parametrize("training", [True, False])
     @parametrize("enable_nested_tensor", [False])
@@ -1110,7 +1109,7 @@ class TestTransformers(NNTestCase):
                 return_all_hiddens=False,
             )[0]
 
-    @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.003)
     @parametrize("input_dim,attn_mask_dim,is_causal",
                  [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
                   (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py
index 99f469d46dc1..a1fbd4fdddc2 100644
--- a/torch/cuda/tunable.py
+++ b/torch/cuda/tunable.py
@@ -591,7 +591,6 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
     transA = layout[1] == "T"
     dtype = dtype_dict.get(data_type)
     if data_type == "tf32":
-        # User must still set HIPBLASLT_ALLOW_TF32=1
         torch.backends.cuda.matmul.allow_tf32 = True
     else:
         torch.backends.cuda.matmul.allow_tf32 = False
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index be284429114f..846d2b407684 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -181,9 +181,6 @@ def tf32_off():
 
 @contextlib.contextmanager
 def tf32_on(self, tf32_precision=1e-5):
-    if torch.version.hip:
-        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
     old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
     old_precision = self.precision
     try:
@@ -192,11 +189,6 @@ def tf32_on(self, tf32_precision=1e-5):
         with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
             yield
     finally:
-        if torch.version.hip:
-            if hip_allow_tf32 is not None:
-                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-            else:
-                del os.environ["HIPBLASLT_ALLOW_TF32"]
         torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
         self.precision = old_precision
 
@@ -246,7 +238,7 @@ def tf32_enabled():
 # if device is specified, it will check if device is cuda
 # if dtype is specified, it will check if dtype is float32 or complex64
 # tf32 and fp32 are different only when all the three checks pass
-def tf32_on_and_off(tf32_precision=1e-5, only_if=True):
+def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True):
     def with_tf32_disabled(self, function_call):
         with tf32_off():
             function_call()
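
After this change, TF32 for float32 matmuls is driven purely by the existing PyTorch switches on both ROCm and CUDA; no HIPBLASLT_ALLOW_TF32 environment variable is consulted. A minimal usage sketch (assuming a build with a TF32-capable GPU; the tensor shapes are illustrative):

    import torch

    # Opt in to TF32 matmuls; per the Context.cpp hunk above this is the same as
    # torch.set_float32_matmul_precision("high").
    torch.backends.cuda.matmul.allow_tf32 = True

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    c = a @ b  # may execute in TF32 on supported hardware

    # Revert to strict FP32 ("highest" precision).
    torch.backends.cuda.matmul.allow_tf32 = False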