Add Autocast Support for FakeTensors / use fake device dispatch keys (#82449)

From PR:
```
Note: [Fake Tensor Dispatch Keys]
In order to model the behavior of device-specific autocast
and autograd logic, we update the dispatch keys of FakeTensors
to reflect their fake device. This includes the BackendComponent
(DispatchKey::Meta -> DispatchKey::CUDA), and also the
BackendComponent-related Autocast and Autograd keys.
__torch_dispatch__ sits below
Autocast and Autograd, and is only invoked when we are at the
kernel for the BackendComponent. Then, we add Meta to the
thread-local dispatch include set to hit the meta kernel
instead of the kernel of the BackendComponent for the fake device.
```
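As an illustration of the key-set change: a fake "cuda" tensor now carries the CUDA backend plus its Autograd/Autocast keys rather than Meta. A minimal sketch, assuming a recent PyTorch build; `torch._C._dispatch_key_set` is the binding whose body appears in the diff context further down:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# No real GPU is required: FakeTensorMode only models the device.
with FakeTensorMode():
    t = torch.empty(2, device="cuda")
    # Expect CUDA / AutogradCUDA / AutocastCUDA in the key set, not Meta.
    print(torch._C._dispatch_key_set(t))
```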

Also adds the `conv1/2/3d.padding` overloads to the autocast rule set. Without that fix, the dtype of a FakeTensor convolution under autocast would diverge from the dtype a real tensor run produces, as sketched below.
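A hedged sketch of the behavior the fix restores (the float16 result assumes CUDA autocast's default lower-precision dtype for convolutions):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(1, 3, 8, 8, device="cuda")
    w = torch.randn(4, 3, 3, 3, device="cuda")
    with torch.autocast("cuda"):
        # String padding dispatches to the conv2d.padding overload; with
        # that overload in the autocast rule set, the fake result is
        # float16, matching a real CUDA run.
        out = torch.nn.functional.conv2d(x, w, padding="same")
print(out.dtype)  # torch.float16
```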

See: https://github.com/pytorch/pytorch/issues/81608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82449
Approved by: https://github.com/ezyang
Author: Elias Ellison
Date: 2022-08-01 18:02:10 +00:00
Committed by: PyTorch MergeBot
Parent: a64c981d09
Commit: 642aed8b99
13 changed files with 189 additions and 35 deletions


@@ -15,6 +15,7 @@
 #include <ATen/core/Vitals.h>
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
+#include <c10/core/DispatchKeySet.h>
 #include <c10/util/Logging.h>
 #include <c10/util/irange.h>
 #include <libshm.h>
@@ -1259,6 +1260,26 @@ Call this whenever a new thread is created in order to propagate values from
     return toString(x.key_set());
   });
+  py_module.def("_add_meta_to_tls_dispatch_include", []() {
+    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
+    local_keyset.included_ = local_keyset.included_ | key_set;
+    c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+  });
+  py_module.def("_remove_meta_from_tls_dispatch_include", []() {
+    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
+    auto k = key_set.highestBackendKey();
+    local_keyset.included_ = local_keyset.included_.remove_backend(k);
+    c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+  });
+  py_module.def("_dump_local_tls_set", []() {
+    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+    std::cout << "Included: " << toString(local_keyset.included_) << "\n";
+    std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
+  });
   const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
   THPDefaultCPUGenerator =
       (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
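The new `py_module.def` bindings land on `torch._C`. A minimal usage sketch; the toggle is normally driven by FakeTensor's own dispatch machinery, and calling it by hand like this is for illustration and debugging only:

```python
import torch

# Push Meta into the thread-local include set so subsequent dispatches
# hit meta kernels, then restore the previous state.
torch._C._add_meta_to_tls_dispatch_include()
try:
    torch._C._dump_local_tls_set()  # prints the "Included:" / "Excluded:" TLS sets
finally:
    torch._C._remove_meta_from_tls_dispatch_include()
```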