mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 23:45:05 +08:00
Merge CUDAFuture into ivalue::Future (#57052)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57052 This PR caps a stack whose goal was to merge CUDAFuture into ivalue::Future. CUDAFuture used to be a subclass of ivalue::Future, which was already pretty good, but it meant that in several places we needed `#ifdef`s or registries in order to create the right type of class, which was annoying. We've made CUDAFuture device-agnostic, by using generic helpers, so that it doesn't depend on CUDA. Now all its code can be inserted into ivalue::Future. This PR does this very naively, by copy-pasting CUDAFuture's code into the (previously empty) virtual methods of ivalue::Future. This helps ensure the correctness of this PR, as it's straightforward to see it behaves exactly like before. However we probably want to polish it a bit later to iron out so wrinkles. ghstack-source-id: 127713138 (Note: this ignores all push blocking failures!) Test Plan: CI Reviewed By: mrshenli Differential Revision: D28036829 fbshipit-source-id: 3e5b16402f5dc245c1fcb9d7bf06db64dcb0d2a3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
71c2f88b90
commit
311ad5e3af
@ -97,9 +97,6 @@
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
|
||||
#include <ATen/core/function_schema.h>
|
||||
#ifdef USE_CUDA
|
||||
#include <ATen/cuda/CUDAFuture.h>
|
||||
#endif
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/iostream.h>
|
||||
@ -1213,31 +1210,19 @@ void initJITBindings(PyObject* module) {
|
||||
py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
|
||||
m, "Future")
|
||||
.def(py::init([](const std::vector<py::object>& pyDevices = {}) {
|
||||
c10::intrusive_ptr<c10::ivalue::Future> fut;
|
||||
#ifdef USE_CUDA
|
||||
if (pyDevices.empty()) {
|
||||
fut = c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get());
|
||||
} else {
|
||||
std::vector<c10::Device> devices;
|
||||
devices.reserve(pyDevices.size());
|
||||
for (const py::object& pyDev : pyDevices) {
|
||||
TORCH_CHECK_TYPE(
|
||||
THPDevice_Check(pyDev.ptr()),
|
||||
"Expected torch.device, got ",
|
||||
py::repr(pyDev));
|
||||
auto device = reinterpret_cast<THPDevice*>(pyDev.ptr());
|
||||
devices.emplace_back(device->device);
|
||||
}
|
||||
fut = c10::make_intrusive<at::cuda::CUDAFuture>(
|
||||
PyObjectType::get(), std::move(devices));
|
||||
std::vector<c10::Device> devices;
|
||||
devices.reserve(pyDevices.size());
|
||||
for (const py::object& pyDev : pyDevices) {
|
||||
TORCH_CHECK_TYPE(
|
||||
THPDevice_Check(pyDev.ptr()),
|
||||
"Expected torch.device, got ",
|
||||
py::repr(pyDev));
|
||||
auto device = reinterpret_cast<THPDevice*>(pyDev.ptr());
|
||||
devices.emplace_back(device->device);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK_VALUE(
|
||||
pyDevices.empty(),
|
||||
"Tried to instantiate a Future with some devices, but PyTorch was built without CUDA support");
|
||||
fut = c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get());
|
||||
#endif
|
||||
return std::make_shared<PythonFutureWrapper>(std::move(fut));
|
||||
return std::make_shared<PythonFutureWrapper>(
|
||||
c10::make_intrusive<c10::ivalue::Future>(
|
||||
PyObjectType::get(), std::move(devices)));
|
||||
}))
|
||||
.def(
|
||||
"done",
|
||||
|
||||
Reference in New Issue
Block a user