Compare commits

...

3 Commits

Author SHA1 Message Date
e38ac55aae cleanup and remove alert test 2025-09-10 22:46:12 +00:00
f20c28eba1 update 2025-09-02 17:48:42 +00:00
daac58237f check in 2025-08-28 23:43:31 +00:00
7 changed files with 11 additions and 164 deletions

View File

@ -279,45 +279,10 @@ bool Context::userEnabledOverrideableSDP() const {
return enabled_overrideable;
}
static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
#ifdef USE_ROCM
static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
#endif
bool Context::checkCuBLASConfigDeterministic() {
// If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
// is set to deterministic setting
if (hasCUDART()) {
const auto workspace_config = c10::utils::get_env(cublas_config_var_name);
return (workspace_config == cublas_deterministic_configs[0] || workspace_config == cublas_deterministic_configs[1]);
}
return true;
}
void Context::alertCuBLASConfigNotDeterministic() const {
static const bool cublas_config_deterministic = checkCuBLASConfigDeterministic();
if (C10_LIKELY(!deterministicAlgorithms() || cublas_config_deterministic)) {
return;
}
auto msg = c10::str(
"Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ",
"`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ",
"it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this ",
"case, you must set an environment variable before running your PyTorch application: ",
cublas_config_var_name, "=", cublas_deterministic_configs[0], " or ",
cublas_config_var_name, "=", cublas_deterministic_configs[1], ". For more information, go to ",
"https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility"
);
if (deterministicAlgorithmsWarnOnly()) {
TORCH_WARN(msg);
} else {
TORCH_CHECK(false, msg);
}
}
bool Context::benchmarkCuDNN() const {
return benchmark_cudnn;
}

View File

@ -310,13 +310,7 @@ class TORCH_API Context {
//
// * Throw an error when `Context::deterministicAlgorithms()` is true. Most
// of the time, this should be accomplished by calling
// `at::globalContext().alertNotDeterminstic()`. However, if the
// nondeterministic behavior is caused by the CuBLAS workspace
// configuration in CUDA >= 10.2,
// `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
// called instead (in this case, a comment explaining why the operation is
// nondeterministic is not necessary). See below for details on these
// methods.
// `at::globalContext().alertNotDeterminstic().
//
// * Have an entry in the list of nondeterministic PyTorch operations in the
// docstring of `use_deterministic_algorithms()` in torch/__init__.py
@ -340,12 +334,6 @@ class TORCH_API Context {
// Throws an error if `Context::deterministicAlgorithms()` is true
static void alertNotDeterministic(std::string_view const& caller);
// Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
// >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
// ":4096:8". For more details:
// https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
void alertCuBLASConfigNotDeterministic() const;
void setFloat32MatmulPrecision(const std::string& s);
void setFloat32Precision(
const std::string& backend,
@ -429,7 +417,6 @@ class TORCH_API Context {
}
private:
static bool checkCuBLASConfigDeterministic();
std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES> init_;
bool enabled_cudnn = true;
bool deterministic_cudnn = false;

View File

@ -436,7 +436,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented");
}
globalContext().alertCuBLASConfigNotDeterministic();
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -570,8 +569,6 @@ inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_D
template <>
void bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -583,8 +580,6 @@ void bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
template <>
void bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -596,8 +591,6 @@ void bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
template <>
void bgemm_internal_cublas<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -611,8 +604,6 @@ void bgemm_internal_cublas<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::co
template <>
void bgemm_internal_cublas<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -626,8 +617,6 @@ void bgemm_internal_cublas<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::com
template <typename C_Dtype>
inline void bgemm_internal_cublas_half_helper(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, C_Dtype)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -697,8 +686,6 @@ inline void bgemm_internal_cublas_half_helper(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYP
template <typename C_Dtype>
inline void bgemm_internal_cublas_bfloat16_helper(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, C_Dtype)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
BGEMM_CHECK_ARGVALUES(at::BFloat16);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
@ -1027,8 +1014,6 @@ inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dty
template <>
void gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -1040,8 +1025,6 @@ void gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
template <>
void gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -1053,8 +1036,6 @@ void gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
template <>
void gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -1068,8 +1049,6 @@ void gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::comp
template <>
void gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -1083,8 +1062,6 @@ void gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::compl
template <typename C_Dtype>
inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, C_Dtype)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -1191,7 +1168,6 @@ inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(
template <typename C_Dtype>
inline void gemm_internal_cublas_bfloat16_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, C_Dtype)) {
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
@ -2415,8 +2391,6 @@ void trsmBatched<c10::complex<double>>(
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
@ -2432,8 +2406,6 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
// gemv is bw bound, and does not benefit from TF32. But the precision
// loss still happens on TF32. So we disable it here.
NoTF32Guard disable_tf32;
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
@ -2446,8 +2418,6 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
@ -2461,8 +2431,6 @@ void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
// gemv is bw bound, and does not benefit from TF32. But the precision
// loss still happens on TF32. So we disable it here.
NoTF32Guard disable_tf32;
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);

View File

@ -125,10 +125,6 @@ deterministic implementation will be used::
[[ 0.1509, 1.8027],
[ 0.0333, -1.1444]]], device='cuda:0')
Furthermore, if you are using CUDA tensors, and your CUDA version is 10.2 or greater, you
should set the environment variable `CUBLAS_WORKSPACE_CONFIG` according to CUDA documentation:
`<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_
CUDA convolution determinism
----------------------------
While disabling CUDA convolution benchmarking (discussed above) ensures that

View File

@ -297,6 +297,16 @@ class TestMatmulCuda(TestCase):
# cross comparison
self.assertEqual(out1_gpu, out2_gpu[0])
@onlyCUDA
@skipIfRocm
@parametrize("shape", [2**i for i in range(5, 14)])
@dtypes(torch.float, torch.half, torch.bfloat16)
def test_cublas_deterministic(self, device, shape, dtype):
inp = torch.randn(shape, shape, device=device, dtype=dtype)
first = torch.matmul(inp, inp)
for _ in range(10):
self.assertEqual(first, torch.matmul(inp, inp), atol=0., rtol=0.)
def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist):
for a, b, gO, agrad, bgrad, out in zip(alist, blist, gOlist, agradlist, bgradlist, outlist):
a = a.clone().detach().requires_grad_()

View File

@ -1232,74 +1232,6 @@ class TestTorchDeviceType(TestCase):
_test_in_place_broadcastable(small2, small_expanded, large_expanded)
_test_in_place_broadcastable(small2, small, large)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyCUDA
@wrapDeterministicFlagAPITest
def test_cublas_config_nondeterministic_alert(self, device):
test_cases = [
# (function, (tensor sizes))
('mm', ((2, 2), (2, 2),)),
('mv', ((2, 2), (2,),)),
('bmm', ((1, 2, 2), (1, 2, 2),))]
test_configs = [
# (CuBLAS workspace config, is deterministic)
('garbage', False),
(None, False),
(':4096:8', True),
(':16:8', True)]
cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
is_cuda10_2_or_higher = (torch.version.cuda is not None)
def test_case_info(fn_name, config):
return f'function "{fn_name}" with config "{"" if config is None else config}"'
# Create processes to test each combination of test cases and config settings
for fn_name, arg_sizes in test_cases:
for config, is_config_deterministic in test_configs:
env = os.environ.copy()
if config is None:
if env.get(cublas_var_name) is not None:
del env[cublas_var_name]
else:
env[cublas_var_name] = config
should_throw_error = is_cuda10_2_or_higher and not is_config_deterministic
script = f"""
import torch
torch.use_deterministic_algorithms(True)
fn = torch.{fn_name}
arg_sizes = {arg_sizes}
device = '{device}'
should_throw_error = {should_throw_error}
args = []
for arg_size in arg_sizes:
args.append(torch.randn(*arg_size, device=device))
try:
fn(*args)
except RuntimeError as e:
if not should_throw_error:
raise RuntimeError('Did not expect any error to be raised')
elif 'Deterministic behavior was enabled with either' not in str(e):
raise RuntimeError('Expected a CuBLAS nondeterministic error, but got a different error')
else:
if should_throw_error:
raise RuntimeError('Expected a CuBLAS nondeterministic error, but it was not raised')
"""
try:
subprocess.check_output(
[sys.executable, '-c', script],
stderr=subprocess.STDOUT,
# On Windows, opening the subprocess with the default CWD makes `import torch`
# fail, so just set CWD to this script's directory
cwd=os.path.dirname(os.path.realpath(__file__)),
env=env)
except subprocess.CalledProcessError as e:
self.fail(msg=(
f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n'
+ e.output.decode("utf-8")))
@onlyCPU
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
@dtypes(*get_all_qint_dtypes())

View File

@ -1408,17 +1408,6 @@ def use_deterministic_algorithms(
:attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
See the documentation for that attribute for more information.
A handful of CUDA operations are nondeterministic if the CUDA version is
10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more
details: `<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_
If one of these environment variable configurations is not set, a :class:`RuntimeError`
will be raised from these operations when called with CUDA tensors:
* :func:`torch.mm`
* :func:`torch.mv`
* :func:`torch.bmm`
Note that deterministic operations tend to have worse performance than
nondeterministic operations.