Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
Add Autocast Support for FakeTensors / use fake device dispatch keys (#82449)
From PR:

```
Note: [Fake Tensor Dispatch Keys]

In order to model the behavior of device-specific autocast and autograd logic,
we update the dispatch keys of FakeTensors to reflect their fake device. This
includes the BackendComponent (DispatchKey::Meta -> DispatchKey::CUDA) and also
the BackendComponent-related Autocast and Autograd keys. __torch_dispatch__
sits below Autocast and Autograd, and is only invoked when we are at the kernel
for the BackendComponent. Then, we add Meta to the thread-local dispatch
include set to hit the meta kernel instead of the kernel of the
BackendComponent for the fake device.
```

Also adds the `conv1/2/3d.padding` operators to the Autocast rule set. Without that fix, the FakeTensor dtype would diverge from the dtype produced on a real device. See: https://github.com/pytorch/pytorch/issues/81608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82449
Approved by: https://github.com/ezyang
commit 642aed8b99
parent a64c981d09
committed by PyTorch MergeBot
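To picture the divergence the message refers to: under autocast, CUDA convolutions execute in float16, so a faithful FakeTensor run must produce float16 as well. Below is a minimal repro sketch, not part of the diff, assuming the FakeTensorMode import path of this era of the codebase (`torch._subclasses.fake_tensor`); no real GPU is needed since every tensor is fake.

```python
# Sketch of the dtype divergence fixed by this PR (see issue 81608).
# F.conv2d with padding="same" dispatches to the conv2d.padding overload;
# before this change that overload had no Autocast rule, so the fake result
# stayed float32 while a real CUDA run produced float16.
import torch
import torch.nn.functional as F
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(1, 3, 8, 8, device="cuda")  # fake CUDA tensor
    w = torch.randn(4, 3, 3, 3, device="cuda")
    with torch.autocast("cuda"):
        out = F.conv2d(x, w, padding="same")    # conv2d.padding overload
    print(out.dtype)  # expected after this PR: torch.float16
```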
```diff
@@ -15,6 +15,7 @@
 #include <ATen/core/Vitals.h>
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
+#include <c10/core/DispatchKeySet.h>
 #include <c10/util/Logging.h>
 #include <c10/util/irange.h>
 #include <libshm.h>
```
```diff
@@ -1259,6 +1260,26 @@ Call this whenever a new thread is created in order to propagate values from
     return toString(x.key_set());
   });
 
+  py_module.def("_add_meta_to_tls_dispatch_include", []() {
+    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
+    local_keyset.included_ = local_keyset.included_ | key_set;
+    c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+  });
+  py_module.def("_remove_meta_from_tls_dispatch_include", []() {
+    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
+    auto k = key_set.highestBackendKey();
+    local_keyset.included_ = local_keyset.included_.remove_backend(k);
+    c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+  });
+
+  py_module.def("_dump_local_tls_set", []() {
+    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+    std::cout << "Included: " << toString(local_keyset.included_) << "\n";
+    std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
+  });
+
   const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
   THPDefaultCPUGenerator =
       (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
```
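The three new bindings land on the `torch._C` extension module (`py_module` in this file binds `torch._C`), take no arguments, and are meant to be driven by FakeTensor internals rather than user code. A minimal usage sketch, assuming the names are reachable as `torch._C` attributes:

```python
# Hedged sketch: toggle the Meta bit in the thread-local dispatch include set
# around a dispatch, the way FakeTensor kernel invocation does. These are
# private helpers; names and behavior may change between releases.
import torch

torch._C._add_meta_to_tls_dispatch_include()  # route dispatch to meta kernels
try:
    torch._C._dump_local_tls_set()  # debug aid: prints included/excluded keysets
    # ... invoke an operator here; with Meta in the TLS include set, the meta
    # kernel runs even though the FakeTensor's keyset advertises a real backend.
finally:
    torch._C._remove_meta_from_tls_dispatch_include()
```

Note the asymmetry in the C++ above: adding ORs in the full `Meta` key, while removal goes through `highestBackendKey()`/`remove_backend()`, which strips the Meta backend bit from the include set rather than removing the key wholesale.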