Defer lazyInitCUDA() until needed (#11893)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11893

This is needed to run binaries compiled with CUDA support on on CPU-only machines.

Reviewed By: teng-li

Differential Revision: D9972872

fbshipit-source-id: 7e4107925b3cd4d2fcf84ae532e800ab65f4b563
This commit is contained in:
Pieter Noordhuis
2018-09-20 11:58:14 -07:00
committed by Facebook Github Bot
parent 9cd0ae5e2d
commit 24ec813967
2 changed files with 6 additions and 18 deletions

View File

@ -310,10 +310,6 @@ ProcessGroupGloo::ProcessGroupGloo(
for (size_t i = 0; i < threads_.size(); i++) {
threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this);
}
#ifdef USE_CUDA
thcState_ = ::at::globalContext().lazyInitCUDA();
#endif
}
ProcessGroupGloo::~ProcessGroupGloo() {
@ -605,14 +601,15 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::broadcast(
// In case of CUDA, ensure that operations that are queued after
// this collective wait for the collective to complete.
if (key.type->is_cuda()) {
synchronizeStreams(thcState_, entry);
auto thcState = ::at::globalContext().lazyInitCUDA();
synchronizeStreams(thcState, entry);
entry->run = [=]() mutable {
entry->algorithm->run();
for (size_t i = 0; i < tensors.size(); i++) {
// The THCStreamGuard is a RAII wrapper for temporarily
// overriding the current THCStream. This also sets the
// current device to the stream's device.
THCStreamGuard guard(thcState_, entry->streams[i]);
THCStreamGuard guard(thcState, entry->streams[i]);
tensors[i].copy_(entry->src[i]);
}
};
@ -655,14 +652,15 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allreduce(
// In case of CUDA, ensure that operations that are queued after
// this collective wait for the collective to complete.
if (key.type->is_cuda()) {
synchronizeStreams(thcState_, entry);
auto thcState = ::at::globalContext().lazyInitCUDA();
synchronizeStreams(thcState, entry);
entry->run = [=]() mutable {
entry->algorithm->run();
for (size_t i = 0; i < tensors.size(); i++) {
// The THCStreamGuard is a RAII wrapper for temporarily
// overriding the current THCStream. This also sets the
// current device to the stream's device.
THCStreamGuard guard(thcState_, entry->streams[i]);
THCStreamGuard guard(thcState, entry->streams[i]);
tensors[i].copy_(entry->src[i]);
}
};

View File

@ -24,11 +24,6 @@
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#ifdef USE_CUDA
// Forward declaration
struct THCState;
#endif
namespace c10d {
// AlgorithmKey is a const identifier for a Gloo algorithm.
@ -389,11 +384,6 @@ class ProcessGroupGloo : public ProcessGroup {
std::mutex queueMutex_;
std::condition_variable queueProduceCV_;
std::condition_variable queueConsumeCV_;
#ifdef USE_CUDA
// Store copy of pointer to THCState retrieved from ::at::globalContext().
THCState* thcState_;
#endif
};
} // namespace c10d