Merge CUDAFuture into ivalue::Future (#57052)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57052

This PR caps a stack whose goal was to merge CUDAFuture into ivalue::Future. CUDAFuture used to be a subclass of ivalue::Future, which was already pretty good, but it meant that in several places we needed `#ifdef`s or registries in order to create the right type of class, which was annoying. We've made CUDAFuture device-agnostic, by using generic helpers, so that it doesn't depend on CUDA. Now all its code can be inserted into ivalue::Future.

This PR does this very naively, by copy-pasting CUDAFuture's code into the (previously empty) virtual methods of ivalue::Future. This helps ensure the correctness of this PR, as it's straightforward to see it behaves exactly like before. However we probably want to polish it a bit later to iron out some wrinkles.
ghstack-source-id: 127713138

(Note: this ignores all push blocking failures!)

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28036829

fbshipit-source-id: 3e5b16402f5dc245c1fcb9d7bf06db64dcb0d2a3
This commit is contained in:
Luca Wehrstedt
2021-04-29 09:29:02 -07:00
committed by Facebook GitHub Bot
parent 71c2f88b90
commit 311ad5e3af
14 changed files with 290 additions and 452 deletions

View File

@@ -97,9 +97,6 @@
#include <caffe2/serialize/inline_container.h>
#include <ATen/core/function_schema.h>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAFuture.h>
#endif
#include <pybind11/functional.h>
#include <pybind11/iostream.h>
@@ -1213,31 +1210,19 @@ void initJITBindings(PyObject* module) {
py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
m, "Future")
.def(py::init([](const std::vector<py::object>& pyDevices = {}) {
c10::intrusive_ptr<c10::ivalue::Future> fut;
#ifdef USE_CUDA
if (pyDevices.empty()) {
fut = c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get());
} else {
std::vector<c10::Device> devices;
devices.reserve(pyDevices.size());
for (const py::object& pyDev : pyDevices) {
TORCH_CHECK_TYPE(
THPDevice_Check(pyDev.ptr()),
"Expected torch.device, got ",
py::repr(pyDev));
auto device = reinterpret_cast<THPDevice*>(pyDev.ptr());
devices.emplace_back(device->device);
}
fut = c10::make_intrusive<at::cuda::CUDAFuture>(
PyObjectType::get(), std::move(devices));
std::vector<c10::Device> devices;
devices.reserve(pyDevices.size());
for (const py::object& pyDev : pyDevices) {
TORCH_CHECK_TYPE(
THPDevice_Check(pyDev.ptr()),
"Expected torch.device, got ",
py::repr(pyDev));
auto device = reinterpret_cast<THPDevice*>(pyDev.ptr());
devices.emplace_back(device->device);
}
#else
TORCH_CHECK_VALUE(
pyDevices.empty(),
"Tried to instantiate a Future with some devices, but PyTorch was built without CUDA support");
fut = c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get());
#endif
return std::make_shared<PythonFutureWrapper>(std::move(fut));
return std::make_shared<PythonFutureWrapper>(
c10::make_intrusive<c10::ivalue::Future>(
PyObjectType::get(), std::move(devices)));
}))
.def(
"done",