[CUDA][AMP] Fix autocast_dtype (#133938)

Fixes #132715

The failure in #132715 is due to `autocast_dtype` being a thread-local variable, which makes `get_autocast_dtype()` return inconsistent values across threads.

To be exact, here is what happens: the AMP dtype is set to `bfloat16` on the main thread, but the `backward` call runs on a side thread, so `at::autocast::prioritize` fails because `lower_precision_fp` on that thread still defaults to `float16`:
6f738d6434/aten/src/ATen/autocast_mode.h (L221-L225)
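For context, the referenced "promote" cast policy can be sketched roughly as follows. This is a simplified paraphrase, not the exact upstream code; the key point is that two lower-precision inputs are only accepted when they both match the autocast dtype seen by the current thread:

```cpp
#include <ATen/ATen.h>

// Simplified sketch of the promote policy in at::autocast::prioritize.
// lower_precision_fp comes from the autocast state of the *current* thread
// (at::kHalf by default, at::kBFloat16 if set via torch.autocast).
at::ScalarType prioritize_sketch(
    at::ScalarType current,
    at::ScalarType next,
    at::ScalarType lower_precision_fp) {
  if (next == at::kDouble) {
    return current; // double inputs are ignored by the promote policy
  }
  if (current == at::kFloat || next == at::kFloat) {
    return at::kFloat; // any float32 input promotes the op to float32
  }
  if (current == lower_precision_fp && next == lower_precision_fp) {
    return lower_precision_fp; // all low-precision inputs must agree
  }
  // Reached when e.g. bfloat16 tensors from the forward pass meet a backward
  // thread whose autocast dtype still defaults to float16 (issue #132715).
  TORCH_CHECK(false, "Unexpected floating ScalarType in at::autocast::prioritize");
  return current;
}
```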

This PR makes `autocast_dtype` consistent across all threads of the forward and backward passes by capturing the per-device autocast dtypes in `ThreadLocalState` and restoring them on worker threads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133938
Approved by: https://github.com/soulitzer
Author: Aidyn-A
Date: 2024-09-05 00:07:31 +00:00
Committed by: PyTorch MergeBot
Parent: 977a909250
Commit: 956da79bda
3 changed files with 47 additions and 2 deletions

aten/src/ATen/ThreadLocalState.cpp

@@ -1,6 +1,7 @@
 #include <ATen/ThreadLocalState.h>
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
 #include <ATen/autocast_mode.h>
 #include <ATen/core/grad_mode.h>
 #endif
@@ -18,7 +19,13 @@ ThreadLocalState::ThreadLocalState()
       torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()),
       python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
       saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()),
       functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
-      saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {}
+      saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
+  for(uint8_t i=0; i<autocast_dtypes_.size(); i++) {
+    autocast_dtypes_[i] = at::autocast::get_autocast_dtype(static_cast<at::DeviceType>(i));
+  }
+#endif
+}
 void ThreadLocalState::set_grad_mode(bool enabled) {
   autograd_tls_.set_grad_mode(enabled);
@@ -54,6 +61,11 @@ void ThreadLocalState::setThreadLocalState(
   at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
   at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
+  for(uint8_t i=0; i<state.autocast_dtypes_.size(); i++) {
+    at::autocast::set_autocast_dtype(static_cast<at::DeviceType>(i), state.autocast_dtypes_[i]);
+  }
+#endif
 }
 } // namespace at
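To see how the dtypes captured and restored above take effect, here is a minimal usage sketch. The `run_on_worker` helper is hypothetical; it only mirrors the pattern by which the autograd engine hands work to its worker threads via `at::launch` and `ThreadLocalStateGuard`:

```cpp
#include <ATen/Parallel.h>
#include <ATen/ThreadLocalState.h>

#include <functional>
#include <utility>

// Hypothetical helper illustrating the capture/restore pattern above.
void run_on_worker(std::function<void()> task) {
  // Snapshot TLS on the launching thread; with this PR the per-device
  // autocast dtypes are part of the snapshot.
  at::ThreadLocalState tls_state;
  at::launch([tls_state, task = std::move(task)]() {
    // Restore the snapshot (including autocast dtypes) on the worker thread,
    // so at::autocast::get_autocast_dtype() matches the launching thread.
    at::ThreadLocalStateGuard guard(tls_state);
    task();
  });
}
```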

aten/src/ATen/ThreadLocalState.h

@@ -78,6 +78,13 @@ class TORCH_API ThreadLocalState {
   // TLS for arbitrary python objects that is registered via hooks
   at::impl::ThreadLocalPythonObjects saved_objects_;
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
+    !defined(BUILD_LITE_INTERPRETER)
+  // TLS for autocast dtypes
+  std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
+      autocast_dtypes_;
+#endif
   friend class ThreadLocalStateGuard;
 };

test/test_autocast.py

@@ -273,6 +273,32 @@ class TestAutocastGPU(TestCase):
         finally:
             torch._C._set_cached_tensors_enabled(False)
+    # index_put under AMP follows a cast policy called "promote",
+    # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230
+    # That means:
+    # (1) double precision is ignored,
+    # (2) if any argument is float, then all arguments are promoted to float,
+    # (3) if all arguments are of lower precision dtype, then all dtypes must be equal to the same amp autocast dtype.
+    # Since AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd execution,
+    # and due to the multi-threaded nature of the autograd, the forward pass is being run in bfloat16, while the backward
+    # pass defaults to float16. The dtype mismatch leads to the error in the policy, as the criteria (3) is not satisfied.
+    # For more info see https://github.com/pytorch/pytorch/issues/132715.
+    def test_autocast_prioritize(self):
+        device = "cuda"
+        dtype = torch.bfloat16
+        with torch.autocast(device_type=device, enabled=True, dtype=dtype):
+            t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True)
+            index = torch.randint(
+                low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device
+            )
+            val = torch.randn(1, dtype=dtype, device=device)
+            res = torch.index_put(t, [index], val)
+            loss = res.mean()
+            loss.backward()
 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):