Make CUDAFuture handle any kind of device type (#57051)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57051

Make CUDAFuture autodetect the device type from its arguments (which thus change from DeviceIndices to full Devices). This in fact turns CUDAFuture into an "AnythingFuture", since it is no longer tied to CUDA in any way. Having made it fully device-agnostic, we'll merge it into ivalue::Future in the next PR. (A sketch of the new call pattern follows the file stats below.)
ghstack-source-id: 127713134

(Note: this ignores all push blocking failures!)

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28032711

fbshipit-source-id: 8ba23b1b0d97f61db8693cd5f3c7bae7989a9bcd
Author: Luca Wehrstedt
Date: 2021-04-29 09:29:02 -07:00
Committed by: Facebook GitHub Bot
Parent commit: cf1595c48b
This commit: 71c2f88b90

8 changed files with 74 additions and 36 deletions
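For illustration, a minimal sketch (not part of the patch) of the call pattern this PR enables: the caller hands CUDAFuture full c10::Device objects and the constructor infers the backend from them, instead of receiving bare CUDA indices with kCUDA hardwired. It mirrors the TensorPipe agent hunk below; the header paths and the helper name makeCudaFuture are assumptions.

    #include <ATen/core/jit_type.h>
    #include <ATen/cuda/CUDAFuture.h>
    #include <c10/core/Device.h>
    #include <c10/util/intrusive_ptr.h>
    #include <vector>

    // The device type (here kCUDA) is detected from the devices themselves;
    // mixing device types is rejected by the constructor's TORCH_CHECK_VALUE.
    c10::intrusive_ptr<at::ivalue::Future> makeCudaFuture() {
      std::vector<c10::Device> devices = {
          c10::Device(c10::kCUDA, 0), c10::Device(c10::kCUDA, 1)};
      return c10::make_intrusive<at::cuda::CUDAFuture>(
          at::AnyClassType::get(), std::move(devices));
    }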


@@ -97,19 +97,48 @@ std::string formatSetOfDevices(
   return oss.str();
 }
 
+const c10::impl::DeviceGuardImplInterface* getImplForDevices(
+    const std::vector<c10::Device>& devices) {
+  if (devices.empty()) {
+    return nullptr;
+  }
+  c10::DeviceType deviceType = devices[0].type();
+  for (size_t idx = 1; idx < devices.size(); idx++) {
+    TORCH_CHECK_VALUE(
+        devices[idx].type() == deviceType,
+        "Expected all devices to be of the same type, but got a mismatch between ",
+        devices[0],
+        " and ",
+        devices[idx]);
+  }
+  return c10::impl::getDeviceGuardImpl(deviceType);
+}
+
 // We need devices to be sorted in order to use set_difference.
-std::vector<c10::DeviceIndex> sortDevices(
-    std::vector<c10::DeviceIndex> devices) {
-  std::sort(devices.begin(), devices.end());
-  return devices;
+std::vector<c10::DeviceIndex> getSortedIndicesOfDevices(
+    const c10::impl::DeviceGuardImplInterface* impl,
+    const std::vector<c10::Device>& devices) {
+  std::vector<bool> isDeviceUsed(impl->deviceCount(), false);
+  for (const c10::Device& device : devices) {
+    TORCH_CHECK_VALUE(
+        device.has_index(), "Expected devices to have indices, got ", device);
+    isDeviceUsed[device.index()] = true;
+  }
+  std::vector<c10::DeviceIndex> deviceIndices;
+  for (c10::DeviceIndex idx = 0; idx < isDeviceUsed.size(); idx++) {
+    if (isDeviceUsed[idx]) {
+      deviceIndices.push_back(idx);
+    }
+  }
+  return deviceIndices;
 }
 
 } // namespace
 
-CUDAFuture::CUDAFuture(at::TypePtr type, std::vector<c10::DeviceIndex> devices)
+CUDAFuture::CUDAFuture(at::TypePtr type, std::vector<c10::Device> devices)
     : at::ivalue::Future(std::move(type)),
-      impl_(c10::impl::getDeviceGuardImpl(c10::kCUDA)),
-      devices_(sortDevices(std::move(devices))) {
+      impl_(getImplForDevices(devices)),
+      devices_(getSortedIndicesOfDevices(impl_, devices)) {
   // Use current device to initialize currentDevice_. This is necessary
   // because preMarkCompletedHook won't be called when the Future contains
   // an error. Uninitialized currentDevice_ could lead to crash when used
@@ -119,7 +148,12 @@ CUDAFuture::CUDAFuture(at::TypePtr type, std::vector<c10::DeviceIndex> devices)
 
 c10::intrusive_ptr<ivalue::Future> CUDAFuture::createInstance(
     at::TypePtr type) {
-  return c10::make_intrusive<CUDAFuture>(std::move(type), devices_);
+  std::vector<c10::Device> devices;
+  devices.reserve(devices_.size());
+  for (const c10::DeviceIndex& index : devices_) {
+    devices.emplace_back(impl_->type(), index);
+  }
+  return c10::make_intrusive<CUDAFuture>(std::move(type), std::move(devices));
 }
 
 /**
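A design note on getSortedIndicesOfDevices above: unlike the old sortDevices helper, it never calls std::sort. It marks each used index in a deviceCount-sized boolean array and then scans that array in ascending order, so the returned indices come out sorted and de-duplicated in a single pass. A self-contained sketch of the same trick, using plain ints instead of c10 types (sortedUniqueIndices is a hypothetical name):

    #include <iostream>
    #include <vector>

    // Mark used indices in a fixed-size bitmap, then scan the bitmap in
    // ascending order: the result is sorted and duplicate-free without
    // any explicit sort.
    std::vector<int> sortedUniqueIndices(int count, const std::vector<int>& used) {
      std::vector<bool> isUsed(count, false);
      for (int idx : used) {
        isUsed[idx] = true;
      }
      std::vector<int> out;
      for (int idx = 0; idx < count; idx++) {
        if (isUsed[idx]) {
          out.push_back(idx);
        }
      }
      return out;
    }

    int main() {
      for (int idx : sortedUniqueIndices(4, {2, 0, 2})) {
        std::cout << idx << " ";  // prints: 0 2
      }
      std::cout << "\n";
    }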


@@ -16,7 +16,7 @@ namespace cuda {
 struct TORCH_CUDA_CPP_API CUDAFuture final : at::ivalue::Future {
  public:
-  CUDAFuture(at::TypePtr type, std::vector<c10::DeviceIndex> devices);
+  CUDAFuture(at::TypePtr type, std::vector<c10::Device> devices);
 
   c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override;


@@ -145,7 +145,7 @@ class IODescriptor: ...
 class JITException: ...
 
 class Future(object):
-    def __init__(self, devices: List[_int]) -> None: ...
+    def __init__(self, devices: List[device]) -> None: ...
     def done(self) -> _bool: ...
     def wait(self) -> Any: ...
     def add_done_callback(self, callback: Callable) -> None: ...


@@ -489,8 +489,13 @@ TensorPipeAgent::TensorPipeAgent(
       [](const std::vector<c10::DeviceIndex>& devices)
           -> std::shared_ptr<JitFuture> {
         if (!devices.empty()) {
+          std::vector<c10::Device> fullDevices;
+          fullDevices.reserve(devices.size());
+          for (const c10::DeviceIndex index : devices) {
+            fullDevices.emplace_back(c10::kCUDA, index);
+          }
           return std::make_shared<at::cuda::CUDAFuture>(
-              at::AnyClassType::get(), devices);
+              at::AnyClassType::get(), std::move(fullDevices));
         } else {
           return std::make_shared<JitFuture>(at::AnyClassType::get());
         }


@@ -1212,18 +1212,28 @@ void initJITBindings(PyObject* module) {
   py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
       m, "Future")
-      .def(py::init([](std::vector<c10::DeviceIndex> devices = {}) {
+      .def(py::init([](const std::vector<py::object>& pyDevices = {}) {
         c10::intrusive_ptr<c10::ivalue::Future> fut;
 #ifdef USE_CUDA
-        if (devices.empty()) {
+        if (pyDevices.empty()) {
           fut = c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get());
         } else {
+          std::vector<c10::Device> devices;
+          devices.reserve(pyDevices.size());
+          for (const py::object& pyDev : pyDevices) {
+            TORCH_CHECK_TYPE(
+                THPDevice_Check(pyDev.ptr()),
+                "Expected torch.device, got ",
+                py::repr(pyDev));
+            auto device = reinterpret_cast<THPDevice*>(pyDev.ptr());
+            devices.emplace_back(device->device);
+          }
           fut = c10::make_intrusive<at::cuda::CUDAFuture>(
               PyObjectType::get(), std::move(devices));
         }
 #else
         TORCH_CHECK_VALUE(
-            devices.empty(),
+            pyDevices.empty(),
             "Tried to instantiate a Future with some devices, but PyTorch was built without CUDA support");
         fut = c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get());
 #endif


@@ -31,16 +31,15 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
         ``torch.cuda.is_available()`` returns ``True``). This is needed to
         ensure proper CUDA stream synchronization. The child futures, returned
         by the ``then`` method, will inherit these devices.
+
+        Args:
+            devices(``List[Union[int, str, torch.device]]``, optional): the set
+                of devices on which tensors contained in this future's value are
+                allowed to reside and on which callbacks are allowed to operate.
         """
         if devices is None:
             devices = []
-        device_indices = []
-        for d in devices:
-            d = torch.device(d)
-            if d.type != "cuda":
-                raise ValueError(f"Expected CUDA devices, got {d}")
-            device_indices.append(d.index)
-        super().__init__(device_indices)
+        super().__init__([torch.device(d) for d in devices])
 
     def done(self) -> bool:
         r"""


@@ -184,18 +184,6 @@ std::string getExceptionMsgFromExceptionPtr(
   }
 }
 
-std::vector<c10::DeviceIndex> getIndicesOfDevices(
-    const std::vector<c10::Device>& devices) {
-  std::vector<c10::DeviceIndex> deviceIndices;
-  deviceIndices.reserve(devices.size());
-  for (const at::Device& device : devices) {
-    TORCH_INTERNAL_ASSERT(device.is_cuda());
-    deviceIndices.push_back(device.index());
-  }
-  return deviceIndices;
-}
-
 } // namespace
 
 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
@@ -1132,7 +1120,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]);
     work->future_ = c10::make_intrusive<at::cuda::CUDAFuture>(
         c10::ListType::create(c10::TensorType::get()),
-        getIndicesOfDevices(devices));
+        devices);
 
     // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
     // future blocks the stream this callback runs on the corresponding
@@ -1228,7 +1216,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::pointToPoint(
     c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]);
     work->future_ = c10::make_intrusive<at::cuda::CUDAFuture>(
         c10::ListType::create(c10::TensorType::get()),
-        getIndicesOfDevices(devices));
+        devices);
 
     work->future_->markCompleted(at::IValue(*work->outputs_));
   }


@@ -5852,7 +5852,9 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture):
 
     @skip_if_lt_x_gpu(1)
     def test_cuda_future_device_not_cuda(self):
-        with self.assertRaisesRegex(ValueError, "Expected CUDA devices, got "):
+        with self.assertRaisesRegex(
+            ValueError, "Expected devices to have indices, got cpu"
+        ):
             fut = Future(devices=["cpu"])
 
     def _test_cuda_future_extraction(self, wrapper, unwrapper):
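For context on the error message asserted above: a bare "cpu" string parses to a c10::Device with no index attached, which the new getSortedIndicesOfDevices helper rejects via its has_index() check. A minimal standalone sketch of that behavior (assuming libc10 is available to link against):

    #include <c10/core/Device.h>
    #include <iostream>

    int main() {
      c10::Device bare("cpu");       // parses with no index attached
      c10::Device indexed("cuda:1"); // parses with index 1 attached
      // Prints "0 1": the future's device validation fails on the first kind.
      std::cout << bare.has_index() << " " << indexed.has_index() << "\n";
    }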