Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Make CUDAFuture handle any kind of device type (#57051)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57051

Make CUDAFuture autodetect the device type from its arguments (which thus change from DeviceIndices to full Devices). This in fact transforms CUDAFuture into an "AnythingFuture", since it is no longer tied to CUDA in any way. Having made it fully device-agnostic, we'll merge it into ivalue::Future in the next PR.

ghstack-source-id: 127713134

(Note: this ignores all push blocking failures!)

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28032711

fbshipit-source-id: 8ba23b1b0d97f61db8693cd5f3c7bae7989a9bcd
Committed by: Facebook GitHub Bot
Parent: cf1595c48b
Commit: 71c2f88b90
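To make the signature change concrete, here is a hedged caller-side sketch (not taken from the PR; the include path and the helper name are assumptions) showing the new constructor being fed full c10::Device objects instead of bare indices, with the device type autodetected from the vector:

#include <ATen/cuda/CUDAFuture.h> // assumed header location

// Sketch only: the device type (here CUDA) is no longer hard-coded inside
// the future; it is inferred from the devices, and a mixed list such as
// {cuda:0, cpu} fails the same-type check with a ValueError.
c10::intrusive_ptr<at::cuda::CUDAFuture> makeFuture() {
  std::vector<c10::Device> devices = {
      c10::Device(c10::kCUDA, 0), c10::Device(c10::kCUDA, 1)};
  return c10::make_intrusive<at::cuda::CUDAFuture>(
      at::AnyClassType::get(), std::move(devices));
}

Previously the same call took std::vector<c10::DeviceIndex>{0, 1}, which is why several call sites in this diff grow a small index-to-Device conversion.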
@@ -97,19 +97,48 @@ std::string formatSetOfDevices(
   return oss.str();
 }
 
+const c10::impl::DeviceGuardImplInterface* getImplForDevices(
+    const std::vector<c10::Device>& devices) {
+  if (devices.empty()) {
+    return nullptr;
+  }
+  c10::DeviceType deviceType = devices[0].type();
+  for (size_t idx = 1; idx < devices.size(); idx++) {
+    TORCH_CHECK_VALUE(
+        devices[idx].type() == deviceType,
+        "Expected all devices to be of the same type, but got a mismatch between ",
+        devices[0],
+        " and ",
+        devices[idx]);
+  }
+  return c10::impl::getDeviceGuardImpl(deviceType);
+}
+
-// We need devices to be sorted in order to use set_difference.
-std::vector<c10::DeviceIndex> sortDevices(
-    std::vector<c10::DeviceIndex> devices) {
-  std::sort(devices.begin(), devices.end());
-  return devices;
+std::vector<c10::DeviceIndex> getSortedIndicesOfDevices(
+    const c10::impl::DeviceGuardImplInterface* impl,
+    const std::vector<c10::Device>& devices) {
+  std::vector<bool> isDeviceUsed(impl->deviceCount(), false);
+  for (const c10::Device& device : devices) {
+    TORCH_CHECK_VALUE(
+        device.has_index(), "Expected devices to have indices, got ", device);
+    isDeviceUsed[device.index()] = true;
+  }
+  std::vector<c10::DeviceIndex> deviceIndices;
+  for (c10::DeviceIndex idx = 0; idx < isDeviceUsed.size(); idx++) {
+    if (isDeviceUsed[idx]) {
+      deviceIndices.push_back(idx);
+    }
+  }
+  return deviceIndices;
 }
 
 } // namespace
 
-CUDAFuture::CUDAFuture(at::TypePtr type, std::vector<c10::DeviceIndex> devices)
+CUDAFuture::CUDAFuture(at::TypePtr type, std::vector<c10::Device> devices)
     : at::ivalue::Future(std::move(type)),
-      impl_(c10::impl::getDeviceGuardImpl(c10::kCUDA)),
-      devices_(sortDevices(std::move(devices))) {
+      impl_(getImplForDevices(devices)),
+      devices_(getSortedIndicesOfDevices(impl_, devices)) {
   // Use current device to initialize currentDevice_. This is necessary
   // because preMarkCompletedHook won't be called when the Future contains
   // an error. Uninitialized currentDevice_ could lead to crash when used
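The two helpers above carry the core of the change: getImplForDevices derives a single DeviceGuardImplInterface from whatever device type the caller passed (enforcing that all devices share it), and getSortedIndicesOfDevices normalizes the device list into sorted, deduplicated indices via a usage bitmap. Below is a minimal standalone distillation of that logic, with invented plain-C++ stand-ins for the c10 types (commonDeviceType and sortedIndicesOfDevices are my names, not the PR's), that compiles and runs on its own:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Invented stand-ins for c10::DeviceType / c10::DeviceIndex / c10::Device.
using DeviceType = int8_t;
using DeviceIndex = int8_t;
struct Device {
  DeviceType type;
  DeviceIndex index; // -1 plays the role of "no index", like c10::Device
};

// Mirrors getImplForDevices: every device must share one type.
DeviceType commonDeviceType(const std::vector<Device>& devices) {
  DeviceType type = devices.at(0).type;
  for (size_t idx = 1; idx < devices.size(); idx++) {
    if (devices[idx].type != type) {
      throw std::invalid_argument("Expected all devices to be of the same type");
    }
  }
  return type;
}

// Mirrors getSortedIndicesOfDevices: a usage bitmap yields the indices
// sorted and deduplicated in a single pass.
std::vector<DeviceIndex> sortedIndicesOfDevices(
    const std::vector<Device>& devices, size_t deviceCount) {
  std::vector<bool> isDeviceUsed(deviceCount, false);
  for (const Device& device : devices) {
    if (device.index < 0) {
      throw std::invalid_argument("Expected devices to have indices");
    }
    isDeviceUsed[device.index] = true;
  }
  std::vector<DeviceIndex> deviceIndices;
  for (size_t idx = 0; idx < isDeviceUsed.size(); idx++) {
    if (isDeviceUsed[idx]) {
      deviceIndices.push_back(static_cast<DeviceIndex>(idx));
    }
  }
  return deviceIndices;
}

int main() {
  // Type 0 with indices {1, 0, 1}: the duplicate collapses, order is fixed.
  std::vector<Device> devices = {{0, 1}, {0, 0}, {0, 1}};
  commonDeviceType(devices); // would throw on a mixed-type list
  for (DeviceIndex idx : sortedIndicesOfDevices(devices, 8)) {
    std::cout << static_cast<int>(idx) << " "; // prints "0 1 "
  }
  std::cout << std::endl;
}

Note that the bitmap replaces the old std::sort-based sortDevices and additionally deduplicates repeated devices in the same pass.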
@@ -119,7 +148,12 @@ CUDAFuture::CUDAFuture(at::TypePtr type, std::vector<c10::DeviceIndex> devices)
 
 c10::intrusive_ptr<ivalue::Future> CUDAFuture::createInstance(
     at::TypePtr type) {
-  return c10::make_intrusive<CUDAFuture>(std::move(type), devices_);
+  std::vector<c10::Device> devices;
+  devices.reserve(devices_.size());
+  for (const c10::DeviceIndex& index : devices_) {
+    devices.emplace_back(impl_->type(), index);
+  }
+  return c10::make_intrusive<CUDAFuture>(std::move(type), std::move(devices));
 }
 
 /**
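A consequence of storing only impl_ plus sorted indices is visible in createInstance above: child futures (for example the ones produced by then()) must widen the stored indices back into full c10::Device values, because the public constructor now takes Devices. The round trip is lossless since construction guaranteed a single common device type, which impl_->type() remembers. A hedged sketch of that invariant (widenToDevices is my name, not the PR's):

// Widening (device type, sorted indices) back to Devices recovers the
// original device set, deduplicated; this is what createInstance relies on.
std::vector<c10::Device> widenToDevices(
    c10::DeviceType type, const std::vector<c10::DeviceIndex>& indices) {
  std::vector<c10::Device> devices;
  devices.reserve(indices.size());
  for (const c10::DeviceIndex index : indices) {
    devices.emplace_back(type, index); // c10::Device(DeviceType, DeviceIndex)
  }
  return devices;
}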
@@ -16,7 +16,7 @@ namespace cuda {
 
 struct TORCH_CUDA_CPP_API CUDAFuture final : at::ivalue::Future {
  public:
-  CUDAFuture(at::TypePtr type, std::vector<c10::DeviceIndex> devices);
+  CUDAFuture(at::TypePtr type, std::vector<c10::Device> devices);
 
   c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override;
 
@@ -145,7 +145,7 @@ class IODescriptor: ...
 class JITException: ...
 
 class Future(object):
-    def __init__(self, devices: List[_int]) -> None: ...
+    def __init__(self, devices: List[device]) -> None: ...
     def done(self) -> _bool: ...
     def wait(self) -> Any: ...
     def add_done_callback(self, callback: Callable) -> None: ...
@@ -489,8 +489,13 @@ TensorPipeAgent::TensorPipeAgent(
       [](const std::vector<c10::DeviceIndex>& devices)
           -> std::shared_ptr<JitFuture> {
         if (!devices.empty()) {
+          std::vector<c10::Device> fullDevices;
+          fullDevices.reserve(devices.size());
+          for (const c10::DeviceIndex index : devices) {
+            fullDevices.emplace_back(c10::kCUDA, index);
+          }
           return std::make_shared<at::cuda::CUDAFuture>(
-              at::AnyClassType::get(), devices);
+              at::AnyClassType::get(), std::move(fullDevices));
         } else {
           return std::make_shared<JitFuture>(at::AnyClassType::get());
         }
@@ -1212,18 +1212,28 @@ void initJITBindings(PyObject* module) {
 
   py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
       m, "Future")
-      .def(py::init([](std::vector<c10::DeviceIndex> devices = {}) {
+      .def(py::init([](const std::vector<py::object>& pyDevices = {}) {
         c10::intrusive_ptr<c10::ivalue::Future> fut;
 #ifdef USE_CUDA
-        if (devices.empty()) {
+        if (pyDevices.empty()) {
           fut = c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get());
         } else {
+          std::vector<c10::Device> devices;
+          devices.reserve(pyDevices.size());
+          for (const py::object& pyDev : pyDevices) {
+            TORCH_CHECK_TYPE(
+                THPDevice_Check(pyDev.ptr()),
+                "Expected torch.device, got ",
+                py::repr(pyDev));
+            auto device = reinterpret_cast<THPDevice*>(pyDev.ptr());
+            devices.emplace_back(device->device);
+          }
           fut = c10::make_intrusive<at::cuda::CUDAFuture>(
               PyObjectType::get(), std::move(devices));
         }
 #else
         TORCH_CHECK_VALUE(
-            devices.empty(),
+            pyDevices.empty(),
             "Tried to instantiate a Future with some devices, but PyTorch was built without CUDA support");
         fut = c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get());
 #endif
@@ -31,16 +31,15 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
             ``torch.cuda.is_available()`` returns ``True``). This is needed to
             ensure proper CUDA stream synchronization. The child futures, returned
             by the ``then`` method, will inherit these devices.
 
         Args:
             devices(``List[Union[int, str, torch.device]]``, optional): the set
                 of devices on which tensors contained in this future's value are
                 allowed to reside and on which callbacks are allowed to operate.
         """
         if devices is None:
             devices = []
-        device_indices = []
-        for d in devices:
-            d = torch.device(d)
-            if d.type != "cuda":
-                raise ValueError(f"Expected CUDA devices, got {d}")
-            device_indices.append(d.index)
-        super().__init__(device_indices)
+        super().__init__([torch.device(d) for d in devices])
 
     def done(self) -> bool:
         r"""
@@ -184,18 +184,6 @@ std::string getExceptionMsgFromExceptionPtr(
   }
 }
 
-std::vector<c10::DeviceIndex> getIndicesOfDevices(
-    const std::vector<c10::Device>& devices) {
-  std::vector<c10::DeviceIndex> deviceIndices;
-  deviceIndices.reserve(devices.size());
-  for (const at::Device& device : devices) {
-    TORCH_INTERNAL_ASSERT(device.is_cuda());
-    deviceIndices.push_back(device.index());
-  }
-  return deviceIndices;
-}
-
 } // namespace
 
 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
@@ -1132,7 +1120,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]);
     work->future_ = c10::make_intrusive<at::cuda::CUDAFuture>(
         c10::ListType::create(c10::TensorType::get()),
-        getIndicesOfDevices(devices));
+        devices);
 
     // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
     // future blocks the stream this callback runs on the corresponding
@@ -1228,7 +1216,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::pointToPoint(
     c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]);
     work->future_ = c10::make_intrusive<at::cuda::CUDAFuture>(
         c10::ListType::create(c10::TensorType::get()),
-        getIndicesOfDevices(devices));
+        devices);
     work->future_->markCompleted(at::IValue(*work->outputs_));
   }
 
@@ -5852,7 +5852,9 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture):
 
     @skip_if_lt_x_gpu(1)
     def test_cuda_future_device_not_cuda(self):
-        with self.assertRaisesRegex(ValueError, "Expected CUDA devices, got "):
+        with self.assertRaisesRegex(
+            ValueError, "Expected devices to have indices, got cpu"
+        ):
             fut = Future(devices=["cpu"])
 
     def _test_cuda_future_extraction(self, wrapper, unwrapper):
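The updated assertion reflects the new failure mode: "cpu" is no longer rejected for being non-CUDA but for lacking a device index, because torch.device("cpu") parses without one and the constructor's has_index() check fires. A tiny standalone check of that c10 behavior (a sketch that only needs the c10 core headers):

#include <c10/core/Device.h>
#include <iostream>

int main() {
  c10::Device cpu("cpu");      // no index in the string, so has_index() is false
  c10::Device cuda0("cuda:0"); // explicit index, so has_index() is true
  std::cout << cpu.has_index() << " " << cuda0.has_index() << std::endl; // "0 1"
}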