Defer lazyInitCUDA() until needed (#11893)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11893

This is needed to run binaries compiled with CUDA support on CPU-only machines.

Reviewed By: teng-li

Differential Revision: D9972872

fbshipit-source-id: 7e4107925b3cd4d2fcf84ae532e800ab65f4b563
committed by Facebook Github Bot
parent 9cd0ae5e2d
commit 24ec813967
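The change is an instance of the standard deferred-initialization pattern: instead of touching the CUDA runtime in the constructor (which fails on hosts without a CUDA driver), the state is fetched only on the code path that actually needs it. Below is a minimal standalone sketch of that idea; the names GpuState, initGpu, and Comm are hypothetical stand-ins, not the PyTorch API.

#include <iostream>

// Hypothetical stand-ins for THCState / lazyInitCUDA(); not the real API.
struct GpuState { int device_count = 1; };

GpuState* initGpu() {
  // Imagine this probes the CUDA driver and would abort on a CPU-only host.
  static GpuState state;
  std::cout << "GPU runtime initialized\n";
  return &state;
}

class Comm {
 public:
  Comm() {
    // Before the change: eager init here would fail on CPU-only machines
    // even if the caller never runs a CUDA collective.
    // gpu_ = initGpu();
  }

  void broadcastCpu() {
    std::cout << "CPU broadcast; GPU runtime never touched\n";
  }

  void broadcastCuda() {
    // After the change: initialize only on the CUDA code path.
    GpuState* gpu = initGpu();
    std::cout << "CUDA broadcast on " << gpu->device_count << " device(s)\n";
  }
};

int main() {
  Comm comm;
  comm.broadcastCpu();   // works on a CPU-only machine
  comm.broadcastCuda();  // first CUDA use triggers initialization
}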
@@ -310,10 +310,6 @@ ProcessGroupGloo::ProcessGroupGloo(
   for (size_t i = 0; i < threads_.size(); i++) {
     threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this);
   }
-
-#ifdef USE_CUDA
-  thcState_ = ::at::globalContext().lazyInitCUDA();
-#endif
 }

 ProcessGroupGloo::~ProcessGroupGloo() {
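Calling ::at::globalContext().lazyInitCUDA() at each collective instead is cheap, because lazy initialization is typically guarded so the expensive work runs only once. A hedged sketch of how such a guard is commonly written with std::call_once; the actual ATen implementation may differ.

#include <mutex>

struct THCStateLike {};  // placeholder for the real THCState

class Context {
 public:
  THCStateLike* lazyInitCUDA() {
    // The expensive runtime initialization runs exactly once, even if
    // many threads race to call lazyInitCUDA() concurrently.
    std::call_once(init_flag_, [&] { state_ = new THCStateLike(); });
    return state_;
  }

 private:
  std::once_flag init_flag_;
  THCStateLike* state_ = nullptr;
};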
@@ -605,14 +601,15 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::broadcast(
   // In case of CUDA, ensure that operations that are queued after
   // this collective wait for the collective to complete.
   if (key.type->is_cuda()) {
-    synchronizeStreams(thcState_, entry);
+    auto thcState = ::at::globalContext().lazyInitCUDA();
+    synchronizeStreams(thcState, entry);
     entry->run = [=]() mutable {
       entry->algorithm->run();
       for (size_t i = 0; i < tensors.size(); i++) {
         // The THCStreamGuard is a RAII wrapper for temporarily
         // overriding the current THCStream. This also sets the
         // current device to the stream's device.
-        THCStreamGuard guard(thcState_, entry->streams[i]);
+        THCStreamGuard guard(thcState, entry->streams[i]);
         tensors[i].copy_(entry->src[i]);
       }
     };
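The THCStreamGuard mentioned in the comments follows the usual RAII guard shape: the constructor saves the current stream/device and installs the new one, and the destructor restores the saved values when the guard leaves scope, even on early return or exception. A simplified sketch with a hypothetical thread-local "current stream", not the real THC interface.

// Hypothetical thread-local current-stream API; not the real THC functions.
struct Stream { int device; };

thread_local Stream currentStream{0};

class StreamGuard {
 public:
  explicit StreamGuard(Stream s) : saved_(currentStream) {
    currentStream = s;  // also switches the "current device" via s.device
  }
  ~StreamGuard() { currentStream = saved_; }  // restore on scope exit

  StreamGuard(const StreamGuard&) = delete;
  StreamGuard& operator=(const StreamGuard&) = delete;

 private:
  Stream saved_;
};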
@@ -655,14 +652,15 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allreduce(
   // In case of CUDA, ensure that operations that are queued after
   // this collective wait for the collective to complete.
   if (key.type->is_cuda()) {
-    synchronizeStreams(thcState_, entry);
+    auto thcState = ::at::globalContext().lazyInitCUDA();
+    synchronizeStreams(thcState, entry);
     entry->run = [=]() mutable {
       entry->algorithm->run();
       for (size_t i = 0; i < tensors.size(); i++) {
         // The THCStreamGuard is a RAII wrapper for temporarily
         // overriding the current THCStream. This also sets the
         // current device to the stream's device.
-        THCStreamGuard guard(thcState_, entry->streams[i]);
+        THCStreamGuard guard(thcState, entry->streams[i]);
         tensors[i].copy_(entry->src[i]);
       }
     };
@@ -24,11 +24,6 @@
 #include <c10d/Types.hpp>
 #include <c10d/Utils.hpp>

-#ifdef USE_CUDA
-// Forward declaration
-struct THCState;
-#endif
-
 namespace c10d {

 // AlgorithmKey is a const identifier for a Gloo algorithm.
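The block removed above was a forward declaration: because the header only stored a THCState* pointer, declaring the incomplete type was enough for the compiler, with no need to pull in the full THC headers. With the pointer member gone (next hunk), the declaration becomes dead code. A generic illustration of the technique, with hypothetical names:

// widget.h
struct Engine;  // forward declaration: size/layout of Engine not needed...

class Widget {
 public:
  void attach(Engine* e) { engine_ = e; }
 private:
  Engine* engine_ = nullptr;  // ...because we only store a pointer to it
};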
@@ -389,11 +384,6 @@ class ProcessGroupGloo : public ProcessGroup {
   std::mutex queueMutex_;
   std::condition_variable queueProduceCV_;
   std::condition_variable queueConsumeCV_;
-
-#ifdef USE_CUDA
-  // Store copy of pointer to THCState retrieved from ::at::globalContext().
-  THCState* thcState_;
-#endif
 };

 } // namespace c10d