Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (#162998)
A few unit-test failures are caused by the `HIPBLASLT_ALLOW_TF32` environment variable; this change removes it from the codebase.

Fixes #157094, #157093, #157092, #157091, #157064, #157063, #157062, #157061, #157042, #157041, #157039, #157004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162998
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Committed by: PyTorch MergeBot
Parent: 14f8d86136
Commit: e769026bcb
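In practice this means ROCm users no longer export an environment variable to opt in to TF32 matmuls; the existing PyTorch knobs are the single source of truth, as on CUDA. A minimal before/after sketch (illustrative only, not part of the diff):

import torch

# Before this change, ROCm builds required the env var in addition to the flag:
#   HIPBLASLT_ALLOW_TF32=1 python train.py
#   torch.backends.cuda.matmul.allow_tf32 = True
# After this change, the flag alone is honored on both ROCm and CUDA:
torch.backends.cuda.matmul.allow_tf32 = True
# Equivalently, via the float32 matmul precision API:
torch.set_float32_matmul_precision("high")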
@@ -180,7 +180,7 @@ void Context::setUserEnabledNNPACK(bool e) {
 }

 bool Context::allowTF32CuDNN(const std::string& op) const {
-  if (op.size() == 0){
+  if (op.empty()){
     bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
     bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
     TORCH_CHECK(
@@ -281,9 +281,6 @@ bool Context::userEnabledOverrideableSDP() const {

 static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
 static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
-#ifdef USE_ROCM
-static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
-#endif

 bool Context::checkCuBLASConfigDeterministic() {
   // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -343,12 +340,6 @@ void Context::setImmediateMiopen(bool b) {
 }

 bool Context::allowTF32CuBLAS() const {
-#ifdef USE_ROCM
-  const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
-  if (allow_tf32 != true) {
-    return false;
-  }
-#endif
   bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
   bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
   TORCH_CHECK(
@@ -362,14 +353,6 @@ bool Context::allowTF32CuBLAS() const {
 }

 void Context::setAllowTF32CuBLAS(bool b) {
-#ifdef USE_ROCM
-  const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
-  if (allow_tf32 != true) {
-    C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
-                              << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
-    return;
-  }
-#endif
   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
   setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
 }
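setAllowTF32CuBLAS() keeps the legacy flag and the per-backend precision setting in sync, so with the ROCm early-return gone the Python-visible state round-trips identically on ROCm and CUDA. A small sanity sketch of that mapping (assumes a recent PyTorch build; no GPU is needed just to read the settings):

import torch

torch.backends.cuda.matmul.allow_tf32 = True
assert torch.get_float32_matmul_precision() == "high"      # HIGH, i.e. "tf32" for cuda/matmul

torch.backends.cuda.matmul.allow_tf32 = False
assert torch.get_float32_matmul_precision() == "highest"   # HIGHEST, i.e. "ieee" for cuda/matmul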
@@ -443,7 +426,7 @@ void Context::setFloat32Precision(const std::string& backend, const std::string&
   std::string msg;
   auto iterp = _fp32_precisions.find(backend);
   TORCH_CHECK(iterp != _fp32_precisions.end());
-  for (auto p : iterp->second) {
+  for (const auto& p : iterp->second) {
     msg += p;
     msg += " ";
   }
@@ -1,6 +1,5 @@
 # Owner(s): ["module: dynamo"]
 import contextlib
-import os

 import torch
 import torch.fx
@@ -196,21 +195,6 @@ class GraphRegionTrackerTests(TestCase):
         )

     def test_mismatched_global_state(self):
-        @contextlib.contextmanager
-        def _hip_allow_tf32():
-            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-            # and only for MI300+
-            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-            try:
-                yield
-            finally:
-                if hip_allow_tf32 is not None:
-                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-                else:
-                    del os.environ["HIPBLASLT_ALLOW_TF32"]
-
         def inner_fn(x, y):
             x1 = x * 1
             y1 = y + 1
@@ -251,31 +235,29 @@ class GraphRegionTrackerTests(TestCase):
         def reset_default_dtype():
             torch.set_default_dtype(old_dtype)

-        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-        with tf32_ctx():
         for ctx in [
             lambda: torch.set_grad_enabled(False),
             torch.autograd.grad_mode.inference_mode,
             lambda: torch.autograd.graph.disable_saved_tensors_hooks(
                 "This is not supported"
             ),
             # lambda: torch.set_num_threads(2), : Unsupported
             (set_default_dtype_bfloat16, reset_default_dtype),
             (
                 lambda: torch.use_deterministic_algorithms(True),
                 lambda: torch.use_deterministic_algorithms(False),
             ),
             # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
             # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
             create_toggle_fns("allow_bf16_reduced_precision_reduction"),
             create_toggle_fns("allow_fp16_reduced_precision_reduction"),
             create_toggle_fns("allow_tf32"),
         ]:
             self.assertExpectedInline(
                 self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
                 """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
             )

     def test_mutation_tracking_simple(self):
         def fn(x, y, z):
@@ -8478,43 +8478,24 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         def fn(x):
             return x + 1

-        import contextlib
-
-        @contextlib.contextmanager
-        def _hip_allow_tf32():
-            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-            # and only for MI300+
-            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-            try:
-                yield
-            finally:
-                if hip_allow_tf32 is not None:
-                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-                else:
-                    del os.environ["HIPBLASLT_ALLOW_TF32"]
-
-        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-        with tf32_ctx():
         initial_state = read_state()
         y = torch.randn(10)
         try:
             for round in range(3):
                 for i in range(len(initial_state)):
                     new_state = [False] * len(initial_state)
                     new_state[i] = True
                     write_state(new_state)
                     assert read_state() == new_state
                     last_state.clear()
                     fn(y)
                     assert last_state == new_state
                     if round == 0:
                         assert cnt == i + 1
                     else:
                         assert cnt == len(initial_state)
         finally:
             write_state(initial_state)

     def test_grad_state_mutated(self):
         prior = torch.is_grad_enabled()
@@ -43,9 +43,6 @@ if IS_WINDOWS and IS_CI:


 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
-# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
-# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
-# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
 if torch.version.hip:
     torch.set_float32_matmul_precision("highest")
 else:
@@ -109,9 +109,6 @@ class TestCaseBase(TestCase):
         if HAS_GPU:
             cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision()
             cls.prior_default_device = torch.get_default_device()
-            # In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
-            # In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
-            # logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
             if torch.version.hip:
                 torch.set_float32_matmul_precision("highest")
             else:
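The test harness records the prior matmul precision and restores it on teardown; the valid values for torch.set_float32_matmul_precision are "highest", "high", and "medium". A short sketch of the save/restore pattern these fixtures rely on (variable names are illustrative):

import torch

prior = torch.get_float32_matmul_precision()
try:
    torch.set_float32_matmul_precision("highest")  # force full-precision fp32 matmuls for the checks
    # ... run precision-sensitive assertions here ...
finally:
    torch.set_float32_matmul_precision(prior)      # restore whatever the suite started with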
@@ -759,53 +759,7 @@ print(t.is_pinned())

         torch._C._cuda_clearCublasWorkspaces()

-    @contextlib.contextmanager
-    def _hip_allow_tf32(self):
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+
-        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-        try:
-            yield
-        finally:
-            if hip_allow_tf32 is not None:
-                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-            else:
-                del os.environ["HIPBLASLT_ALLOW_TF32"]
-
-    @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
-    def test_hipblaslt_allow_tf32(self):
-        tf32_ctx = self._hip_allow_tf32
-        with tf32_ctx():
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "0"
-            # Save original value of allow_tf32
-            orig = torch.backends.cuda.matmul.allow_tf32
-            # If allow_tf32 variable is declared as static in aten/src/ATen/Context.cpp
-            # then matmul.allow_tf32 will return False after this point even if
-            # HIP_BLASLT_ALLOW_TF32 is set to 1 and matmul.allow_tf32 is changed.
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-            # Toggle torch.backends.cuda.matmul.allow_tf32 couple of times.
-            torch.backends.cuda.matmul.allow_tf32 = not orig
-            test1 = torch.backends.cuda.matmul.allow_tf32
-            torch.backends.cuda.matmul.allow_tf32 = orig
-            test2 = torch.backends.cuda.matmul.allow_tf32
-            self.assertNotEqual(test1, test2)
-            # Restore original value of allow_tf32
-            torch.backends.cuda.matmul.allow_tf32 = orig
-
     def test_cublas_allow_tf32_get_set(self):
-        """
-        We only turn on TF32 for MI300 with a special env var. This is because TF32
-        is only available in MI300+ and is in experimental mode (hipblaslt support
-        is current WIP)
-        """
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
-        with tf32_ctx():
-            self._test_cublas_allow_tf32_get_set_inner()
-
-    def _test_cublas_allow_tf32_get_set_inner(self):
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
         )
@@ -820,12 +774,6 @@ print(t.is_pinned())
             torch.backends.cuda.matmul.allow_tf32 = orig

     def test_float32_matmul_precision_get_set(self):
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
-        with tf32_ctx():
-            self._test_float32_matmul_precision_get_set_inner()
-
-    def _test_float32_matmul_precision_get_set_inner(self):
         orig = torch.get_float32_matmul_precision()
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
@@ -109,22 +109,6 @@ def get_tunableop_untuned_filename():
     return untuned_filename

 class TestLinalg(TestCase):
-    @contextlib.contextmanager
-    def _hip_allow_tf32(self):
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Environment variable will be removed in the future.
-        import os
-        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-        try:
-            yield
-        finally:
-            if hip_allow_tf32 is not None:
-                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-            else:
-                del os.environ["HIPBLASLT_ALLOW_TF32"]
-
     def setUp(self):
         super().setUp()
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -5542,13 +5526,8 @@ class TestLinalg(TestCase):
     @runOnRocmArch(MI300_ARCH)
     @dtypes(torch.float)
     def test_tf32_tunableop(self, device, dtype):
-        # Test TunableOp with TF32. Supported by hipblasLT on MI300+.
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Eventually this flag will go away.
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
         try:
-            with self._tunableop_ctx(), tf32_ctx():
+            with self._tunableop_ctx():
                 torch.backends.cuda.matmul.allow_tf32 = True
                 torch.cuda.tunable.set_rotating_buffer_size(0)
@@ -5611,13 +5590,8 @@ class TestLinalg(TestCase):
         # This test is the offline version of test_tf32_tunableop
         import os

-        # Test TunableOp with TF32. Supported by hipblasLT on MI300+.
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Eventually this flag will go away.
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
         try:
-            with self._tunableop_ctx(), tf32_ctx():
+            with self._tunableop_ctx():
                 torch.backends.cuda.matmul.allow_tf32 = True
                 ordinal = torch.cuda.current_device()
                 torch.cuda.tunable.set_rotating_buffer_size(0)
@@ -51,7 +51,6 @@ from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     tf32_on_and_off,
     tf32_enabled,
-    ROCM_VERSION,
 )

 if TEST_FAIRSEQ:
@@ -340,7 +339,7 @@ class TestTransformers(NNTestCase):
         l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
         self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")

-    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.001)
     @parametrize("attn_mask_dim", [2, 3, None])
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
@@ -524,7 +523,7 @@ class TestTransformers(NNTestCase):
         slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
         self.assertEqual(fastpath_output_expanded, slowpath_output)

-    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.001)
     @parametrize("with_no_grad", [True, False])
     @parametrize("training", [True, False])
     @parametrize("enable_nested_tensor", [False])
@@ -1110,7 +1109,7 @@ class TestTransformers(NNTestCase):
             return_all_hiddens=False,
         )[0]

-    @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.003)
     @parametrize("input_dim,attn_mask_dim,is_causal",
                  [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
                   (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
@@ -591,7 +591,6 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
     transA = layout[1] == "T"
     dtype = dtype_dict.get(data_type)
     if data_type == "tf32":
-        # User must still set HIPBLASLT_ALLOW_TF32=1
         torch.backends.cuda.matmul.allow_tf32 = True
     else:
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -181,9 +181,6 @@ def tf32_off():

 @contextlib.contextmanager
 def tf32_on(self, tf32_precision=1e-5):
-    if torch.version.hip:
-        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
     old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
     old_precision = self.precision
     try:
@@ -192,11 +189,6 @@ def tf32_on(self, tf32_precision=1e-5):
         with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
             yield
     finally:
-        if torch.version.hip:
-            if hip_allow_tf32 is not None:
-                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-            else:
-                del os.environ["HIPBLASLT_ALLOW_TF32"]
         torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
         self.precision = old_precision
@@ -246,7 +238,7 @@ def tf32_enabled():
 # if device is specified, it will check if device is cuda
 # if dtype is specified, it will check if dtype is float32 or complex64
 # tf32 and fp32 are different only when all the three checks pass
-def tf32_on_and_off(tf32_precision=1e-5, only_if=True):
+def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True):
     def with_tf32_disabled(self, function_call):
         with tf32_off():
             function_call()
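With only_if now keyword-only, call sites that keep the guard must pass it by name; most call sites in this commit simply drop it. A hedged usage sketch (the test class and the only_if condition are illustrative, not taken from the diff):

import torch
from torch.testing._internal.common_cuda import tf32_on_and_off
from torch.testing._internal.common_utils import TestCase

class MyMatmulTests(TestCase):
    @tf32_on_and_off(0.001)  # body runs with TF32 off, then again with TF32 on and the looser tolerance
    def test_matmul(self, device):
        ...

    @tf32_on_and_off(0.005, only_if=torch.version.hip is None)  # guard must now be passed by keyword
    def test_matmul_cuda_only_variant(self, device):
        ...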