diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp
index c22f07866f71..f1ec1a37bf82 100644
--- a/aten/src/ATen/ThreadLocalState.cpp
+++ b/aten/src/ATen/ThreadLocalState.cpp
@@ -1,6 +1,7 @@
 #include <ATen/ThreadLocalState.h>
 
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
+#include <ATen/autocast_mode.h>
 #include <ATen/core/grad_mode.h>
 #endif
 
@@ -18,7 +19,13 @@ ThreadLocalState::ThreadLocalState()
       torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()),
       python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
       saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()),
       functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
-      saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {}
+      saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
+  for (uint8_t i = 0; i < at::COMPILE_TIME_MAX_DEVICE_TYPES; i++) {
+    autocast_dtypes_[i] = at::autocast::get_autocast_dtype(static_cast<at::DeviceType>(i));
+  }
+#endif
+}
 
 void ThreadLocalState::set_grad_mode(bool enabled) {
   autograd_tls_.set_grad_mode(enabled);
@@ -54,6 +61,11 @@ void ThreadLocalState::setThreadLocalState(
   at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
 
   at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
+  for (uint8_t i = 0; i < at::COMPILE_TIME_MAX_DEVICE_TYPES; i++) {
+    at::autocast::set_autocast_dtype(static_cast<at::DeviceType>(i), state.autocast_dtypes_[i]);
+  }
+#endif
 }
 
 } // namespace at
diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h
index 8419499c3a56..721ea9957513 100644
--- a/aten/src/ATen/ThreadLocalState.h
+++ b/aten/src/ATen/ThreadLocalState.h
@@ -78,6 +78,13 @@ class TORCH_API ThreadLocalState {
   // TLS for arbitrary python objects that is registered via hooks
   at::impl::ThreadLocalPythonObjects saved_objects_;
 
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
+    !defined(BUILD_LITE_INTERPRETER)
+  // TLS for autocast dtypes
+  std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
+      autocast_dtypes_;
+#endif
+
   friend class ThreadLocalStateGuard;
 };
 
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 8e702f92296c..0866c5c865bd 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -273,6 +273,32 @@ class TestAutocastGPU(TestCase):
         finally:
             torch._C._set_cached_tensors_enabled(False)
 
+    # index_put under AMP follows a cast policy called "promote",
+    # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230
+    # That means:
+    # (1) double precision is ignored,
+    # (2) if any argument is float, then all arguments are promoted to float,
+    # (3) if all arguments are of lower-precision dtypes, then all dtypes must be equal to the same amp autocast dtype.
+    # Since the AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd
+    # execution; due to the multi-threaded nature of autograd, the forward pass is run in bfloat16 while the
+    # backward pass defaults to float16. The dtype mismatch leads to an error in the promote policy, as
+    # criterion (3) is not satisfied.
+    # For more info see https://github.com/pytorch/pytorch/issues/132715.
+    def test_autocast_prioritize(self):
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        with torch.autocast(device_type=device, enabled=True, dtype=dtype):
+            t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True)
+            index = torch.randint(
+                low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device
+            )
+            val = torch.randn(1, dtype=dtype, device=device)
+
+            res = torch.index_put(t, [index], val)
+
+            loss = res.mean()
+            loss.backward()
+
 
 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):