[CUDA][AMP] Fix autocast_dtype (#133938)
Fixes #132715
The failure in #132715 is due to `autocast_dtype` being a thread-local variable, which causes `get_autocast_dtype()` to return inconsistent values across threads.
To be exact, here is what happens: the AMP dtype is set to `bfloat16` on the main thread, but the `backward` call runs on a side thread, so `at::autocast::prioritize` fails because `lower_precision_fp` defaults to `float16`:
6f738d6434/aten/src/ATen/autocast_mode.h (L221-L225)
This PR makes `autocast_dtype` thread-global so that it is consistent across all threads of the forward and backward passes.
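The thread-locality at the root of the bug can be seen in isolation with a minimal standalone sketch (not part of this patch; it assumes a program built against ATen, and `worker` merely stands in for an autograd engine thread):

    // Minimal sketch (not part of this patch): the autocast dtype is stored in
    // thread-local state, so a value set on the main thread is not visible to a
    // freshly spawned thread, which still sees the CUDA default (float16).
    #include <ATen/autocast_mode.h>
    #include <c10/core/DeviceType.h>
    #include <c10/core/ScalarType.h>
    #include <iostream>
    #include <thread>

    int main() {
      // Main thread: corresponds to entering torch.autocast(..., dtype=torch.bfloat16).
      at::autocast::set_autocast_dtype(c10::kCUDA, c10::kBFloat16);
      std::cout << "main:   "
                << c10::toString(at::autocast::get_autocast_dtype(c10::kCUDA)) << "\n";

      std::thread worker([] {
        // Stand-in for an autograd engine thread: the thread-local value was never
        // propagated here, so the default lower_precision_fp (Half) is returned.
        std::cout << "worker: "
                  << c10::toString(at::autocast::get_autocast_dtype(c10::kCUDA)) << "\n";
      });
      worker.join();
      return 0;
    }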
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133938
Approved by: https://github.com/soulitzer
Committed by: PyTorch MergeBot
Parent commit: 977a909250
Commit: 956da79bda
aten/src/ATen/ThreadLocalState.cpp
@@ -1,6 +1,7 @@
 #include <ATen/ThreadLocalState.h>
 
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
+#include <ATen/autocast_mode.h>
 #include <ATen/core/grad_mode.h>
 #endif
 
@@ -18,7 +19,13 @@ ThreadLocalState::ThreadLocalState()
       torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
       saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
-      saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {}
+      saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
+  for(uint8_t i=0; i<autocast_dtypes_.size(); i++) {
+    autocast_dtypes_[i] = at::autocast::get_autocast_dtype(static_cast<at::DeviceType>(i));
+  }
+#endif
+}
 
 void ThreadLocalState::set_grad_mode(bool enabled) {
   autograd_tls_.set_grad_mode(enabled);
@@ -54,6 +61,11 @@ void ThreadLocalState::setThreadLocalState(
   at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
 
   at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
+  for(uint8_t i=0; i<state.autocast_dtypes_.size(); i++) {
+    at::autocast::set_autocast_dtype(static_cast<at::DeviceType>(i), state.autocast_dtypes_[i]);
+  }
+#endif
 }
 
 } // namespace at
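For context on why extending `ThreadLocalState` is sufficient: the autograd engine snapshots a `ThreadLocalState` on the thread that calls `backward` and re-applies it on its worker threads via `ThreadLocalStateGuard`, so anything captured in the snapshot now survives the thread hop. A simplified sketch of that mechanism (the real engine keeps the snapshot in its GraphTask rather than in a lambda capture; this is only an illustration):

    // Simplified sketch of the propagation path the patch plugs into.
    #include <ATen/ThreadLocalState.h>
    #include <thread>

    void run_like_the_autograd_engine() {
      // Captured on the calling (forward) thread; after this patch the snapshot
      // also contains the per-device autocast dtypes.
      at::ThreadLocalState snapshot;

      std::thread worker([snapshot]() {
        // Restores the snapshot for the lifetime of the guard, so
        // at::autocast::get_autocast_dtype() now agrees with the forward thread.
        at::ThreadLocalStateGuard guard(snapshot);
        // ... execute backward nodes here ...
      });
      worker.join();
    }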
aten/src/ATen/ThreadLocalState.h
@@ -78,6 +78,13 @@ class TORCH_API ThreadLocalState {
   // TLS for arbitrary python objects that is registered via hooks
   at::impl::ThreadLocalPythonObjects saved_objects_;
 
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
+    !defined(BUILD_LITE_INTERPRETER)
+  // TLS for autocast dtypes
+  std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
+      autocast_dtypes_;
+#endif
+
   friend class ThreadLocalStateGuard;
 };
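The new member is simply a fixed-size per-device cache: one `ScalarType` slot per `DeviceType` value, indexed the same way the loops in `ThreadLocalState.cpp` index it. A small illustration (variable and function names here are made up for the example):

    // Illustration of the new member's shape: one cached autocast dtype per
    // DeviceType, indexed by the integer value of the device type.
    #include <array>
    #include <c10/core/DeviceType.h>
    #include <c10/core/ScalarType.h>

    std::array<c10::ScalarType, c10::COMPILE_TIME_MAX_DEVICE_TYPES> autocast_dtypes;

    c10::ScalarType cuda_slot() {
      // e.g. the slot that would hold BFloat16 after entering autocast on CUDA
      return autocast_dtypes[static_cast<size_t>(c10::kCUDA)];
    }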
test/test_autocast.py
@@ -273,6 +273,32 @@ class TestAutocastGPU(TestCase):
         finally:
             torch._C._set_cached_tensors_enabled(False)
 
+    # index_put under AMP follows a cast policy called "promote",
+    # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230
+    # That means:
+    # (1) double precision is ignored,
+    # (2) if any argument is float, then all arguments are promoted to float,
+    # (3) if all arguments are of lower precision dtype, then all dtypes must be equal to the same amp autocast dtype.
+    # Since AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd execution,
+    # and due to the multi-threaded nature of the autograd, the forward pass is being run in bfloat16, while the backward
+    # pass defaults to float16. The dtype mismatch leads to the error in the policy, as the criteria (3) is not satisfied.
+    # For more info see https://github.com/pytorch/pytorch/issues/132715.
+    def test_autocast_prioritize(self):
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        with torch.autocast(device_type=device, enabled=True, dtype=dtype):
+            t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True)
+            index = torch.randint(
+                low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device
+            )
+            val = torch.randn(1, dtype=dtype, device=device)
+
+            res = torch.index_put(t, [index], val)
+
+            loss = res.mean()
+            loss.backward()
+
 
 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):
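To make criterion (3) from the test comment concrete, here is a simplified restatement of the "promote" cast policy in plain C++. The actual logic lives in aten/src/ATen/autocast_mode.h; the function name and shape below are illustrative only, not the real implementation:

    // Simplified model of the "promote" cast policy: doubles are ignored, any
    // float32 argument promotes everything to float32, and otherwise all
    // lower-precision arguments must already match the autocast dtype.
    #include <c10/core/ScalarType.h>
    #include <stdexcept>
    #include <vector>

    c10::ScalarType promote_policy(const std::vector<c10::ScalarType>& args,
                                   c10::ScalarType autocast_dtype) {
      bool any_float = false;
      for (auto t : args) {
        if (t == c10::kDouble) continue;          // (1) double precision is ignored
        if (t == c10::kFloat) any_float = true;   // (2) any float32 wins
      }
      if (any_float) return c10::kFloat;
      for (auto t : args) {
        if (t == c10::kDouble) continue;
        if (t != autocast_dtype) {                // (3) lower-precision args must all
          throw std::runtime_error(               //     match the autocast dtype
              "expected all lower-precision args to be the autocast dtype");
        }
      }
      return autocast_dtype;
    }

With a thread-local autocast dtype, the backward thread effectively calls this with `autocast_dtype = float16` while the forward pass produced `bfloat16` tensors, which is exactly the mismatch that trips criterion (3).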