Revert "Add Accelerator device and shell hooks (#119329)"

This reverts commit 4b9568a360c4a90220e78e43435be8c56bc33fb2.

Reverted https://github.com/pytorch/pytorch/pull/119329 on behalf of https://github.com/huydhn because it breaks an internal build and requires an OSS file update to fix ([comment](https://github.com/pytorch/pytorch/pull/119329#issuecomment-1940278598))
PyTorch MergeBot
2024-02-13 02:23:45 +00:00
parent 7d4b666870
commit 214f06ae3a
16 changed files with 103 additions and 185 deletions

View File

@@ -1,13 +1,11 @@
 #pragma once
 #include <ATen/CPUGeneratorImpl.h>
-#include <ATen/DeviceAccelerator.h>
 #include <ATen/LinalgBackend.h>
 #include <ATen/core/ATenGeneral.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
 #include <ATen/core/Generator.h>
 #include <ATen/core/LegacyTypeDispatch.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/detail/HIPHooksInterface.h>
 #include <ATen/detail/IPUHooksInterface.h>
@@ -58,22 +56,6 @@ class TORCH_API Context {
       AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
     }
   }
-  const AcceleratorHooksInterface& getAcceleratorHooksInterface(
-      c10::optional<c10::DeviceType> opt_device_type = c10::nullopt) {
-    c10::DeviceType device_type = opt_device_type.has_value()
-        ? opt_device_type.value()
-        : at::getAccelerator(true).value();
-    if (device_type == at::kCUDA) {
-      return at::detail::getCUDAHooks();
-    } else if (device_type == at::kMPS) {
-      return at::detail::getMPSHooks();
-    } else if (device_type == at::kPrivateUse1) {
-      return at::detail::getPrivateUse1Hooks();
-    } else {
-      AT_ERROR(
-          c10::DeviceTypeName(device_type), " device type not an accelerator.");
-    }
-  }
   Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
     initCUDAIfNeeded(device_type);
     initHIPIfNeeded(device_type);
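
The helper removed above dispatched from a device type to the matching per-backend hooks object. After the revert, callers have to pick the backend-specific accessor themselves; a minimal sketch of that caller-side dispatch, assuming only the CUDA and PrivateUse1 hooks that appear elsewhere in this diff (the helper name is hypothetical, not part of this commit):

#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>

// Hypothetical helper: query primary-context state for the two backends the
// reverted engine code still handles explicitly.
static bool device_has_primary_context(c10::DeviceType t, c10::DeviceIndex idx) {
  if (t == at::kCUDA) {
    return at::detail::getCUDAHooks().hasPrimaryContext(idx);
  }
  if (t == at::kPrivateUse1) {
    return at::GetPrivateUse1HooksInterface()->hasPrimaryContext(idx);
  }
  TORCH_CHECK(false, "no hooks wired up for this device type in this sketch");
  return false; // unreachable; keeps -Wreturn-type quiet
}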

View File

@@ -1,31 +0,0 @@
-#include <ATen/DeviceAccelerator.h>
-#include <ATen/Context.h>
-namespace at {
-C10_API std::optional<DeviceType> getAccelerator(bool checked) {
-#define CHECK_NO_CUDA \
-  TORCH_CHECK(!at::hasCUDA(), "Cannot have both CUDA and PrivateUse1");
-#define CHECK_NO_PU1 \
-  TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1");
-  if (is_privateuse1_backend_registered()) {
-    // We explicitly allow PrivateUse1 and another device at the same time
-    // as we use this for testing.
-    // Whenever a PrivateUse1 device is registered, use it first.
-    return kPrivateUse1;
-  } else if (at::hasCUDA()) {
-    CHECK_NO_PU1
-    return kCUDA;
-  } else {
-    TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.")
-    return std::nullopt;
-  }
-#undef CHECK_NO_CUDA
-#undef CHECK_NO_PU1
-}
-} // namespace at

View File

@@ -1,27 +0,0 @@
-#pragma once
-#include <c10/core/DeviceType.h>
-#include <c10/macros/Macros.h>
-#include <ATen/detail/MTIAHooksInterface.h>
-#include <optional>
-// This file defines the top level Accelerator concept for PyTorch.
-// A device is an accelerator per the definition here if:
-// - It is mutually exclusive with all other accelerators
-// - It performs asynchronous compute via a Stream/Event system
-// - It provides a set of common APIs as defined by AcceleratorHooksInterface
-//
-// As of today, accelerator devices are (in no particular order):
-// CUDA, MTIA, PrivateUse1
-// We want to add once all the proper APIs are supported and tested:
-// HIP, MPS, XPU
-namespace at {
-// Ensures that only one accelerator is available (at
-// compile time if possible) and return it.
-// When checked is true, the returned optional always has a value.
-TORCH_API std::optional<DeviceType> getAccelerator(bool checked = false);
-} // namespace at
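
For reference, the header deleted here was meant to be used roughly as follows. This is a sketch based on the comment and signature above, not code from this commit, and it only applies to the pre-revert tree:

#include <ATen/DeviceAccelerator.h>
#include <iostream>

void report_accelerator() {
  // Unchecked query: the optional is empty when no accelerator is present.
  if (auto acc = at::getAccelerator(/*checked=*/false)) {
    // Per the implementation above, *acc is kPrivateUse1 or kCUDA.
    std::cout << "accelerator: " << c10::DeviceTypeName(*acc) << std::endl;
  } else {
    std::cout << "no accelerator available" << std::endl;
  }
  // Checked query: fails with TORCH_CHECK when no accelerator is available,
  // so calling .value() on the result is always safe:
  // auto device_type = at::getAccelerator(/*checked=*/true).value();
}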

View File

@@ -1,21 +0,0 @@
-#pragma once
-#include <c10/core/Device.h>
-namespace at {
-// AcceleratorHooksInterface is a shared interface provided by all
-// accelerators to allow generic code.
-// This inferface is hook-based as it corresponds to all the functions
-// that are going to be called in a generic way from the CPU code.
-struct TORCH_API AcceleratorHooksInterface {
-  // This should never actually be implemented, but it is used to
-  // squelch -Werror=non-virtual-dtor
-  virtual ~AcceleratorHooksInterface() = default;
-  // Whether the device at device_index is fully initialized or not.
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
-};
-} // namespace at
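
The deleted interface asks backends for a single query. A minimal sketch of a backend-side implementation against the pre-revert header, with a hypothetical MyBackendHooks type that is not part of this commit:

#include <ATen/detail/AcceleratorHooksInterface.h>

// Hypothetical backend hooks; illustrates the one pure-virtual method the
// removed interface requires.
struct MyBackendHooks final : at::AcceleratorHooksInterface {
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    // A real backend would report whether this device has been initialized;
    // the toy version claims only device 0 is.
    return device_index == 0;
  }
};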

View File

@@ -4,8 +4,6 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Registry.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 // Forward-declares at::Generator and at::cuda::NVRTC
 namespace at {
 struct Generator;
@@ -59,7 +57,7 @@ constexpr const char* CUDA_HELP =
 // TODO: Consider putting the stub definitions in another class, so that one
 // never forgets to implement each virtual function in the real implementation
 // in CUDAHooks. This probably doesn't buy us much though.
-struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
+struct TORCH_API CUDAHooksInterface {
   // This should never actually be implemented, but it is used to
   // squelch -Werror=non-virtual-dtor
   virtual ~CUDAHooksInterface() = default;
@@ -109,7 +107,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
     TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
   }
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const {
     TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
   }

View File

@@ -4,7 +4,6 @@
 #include <c10/core/Allocator.h>
 #include <ATen/core/Generator.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Registry.h>
@@ -12,7 +11,7 @@
 namespace at {
-struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
+struct TORCH_API MPSHooksInterface {
   // this fails the implementation if MPSHooks functions are called, but
   // MPS backend is not present.
 #define FAIL_MPSHOOKS_FUNC(func) \
@@ -87,9 +86,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
   virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
     FAIL_MPSHOOKS_FUNC(__func__);
   }
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
-    FAIL_MPSHOOKS_FUNC(__func__);
-  }
 #undef FAIL_MPSHOOKS_FUNC
 };

View File

@@ -4,8 +4,6 @@
 #include <c10/util/Registry.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <string>
 namespace at {
@@ -19,7 +17,7 @@ constexpr const char* MTIA_HELP =
     "this error has occurred because you are trying "
     "to use some MTIA's functionality without MTIA extension included.";
-struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
+struct TORCH_API MTIAHooksInterface {
   virtual ~MTIAHooksInterface() = default;
   virtual void initMTIA() const {
@@ -39,14 +37,6 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
         "Cannot query detailed MTIA version without MTIA Extension for PyTorch.",
         MTIA_HELP);
   }
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
-    TORCH_CHECK(
-        false,
-        "Cannot check MTIA primary context without MTIA Extension for PyTorch.",
-        MTIA_HELP);
-  }
 };
 struct TORCH_API MTIAHooksArgs {};

View File

@@ -22,15 +22,4 @@ TORCH_API bool isPrivateUse1HooksRegistered() {
   return privateuse1_hooks != nullptr;
 }
-namespace detail {
-TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks() {
-  TORCH_CHECK(
-      privateuse1_hooks != nullptr,
-      "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first.");
-  return *privateuse1_hooks;
-}
-} // namespace detail
 } // namespace at

View File

@@ -1,14 +1,13 @@
 #pragma once
 #include <ATen/core/Generator.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <c10/core/Allocator.h>
 #include <c10/core/Device.h>
 #include <c10/core/Storage.h>
 #include <c10/util/Exception.h>
 namespace at {
-struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
+struct TORCH_API PrivateUse1HooksInterface {
   virtual ~PrivateUse1HooksInterface() = default;
   virtual const at::Generator& getDefaultGenerator(
       c10::DeviceIndex device_index) {
@@ -29,7 +28,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
         "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
   }
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const {
     TORCH_CHECK_NOT_IMPLEMENTED(
         false,
         "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
@@ -52,10 +51,4 @@ TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface();
 TORCH_API bool isPrivateUse1HooksRegistered();
-namespace detail {
-TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
-} // namespace detail
 } // namespace at
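
The error strings above reference the registration entry point. A hedged sketch of how an out-of-tree backend wires its hooks in, assuming RegisterPrivateUse1HooksInterface takes a raw pointer; MyPrivateUse1Hooks and register_my_backend_hooks are illustrative names, not part of this commit:

#include <ATen/detail/PrivateUse1HooksInterface.h>

struct MyPrivateUse1Hooks : at::PrivateUse1HooksInterface {
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    return true;  // toy backend: every device counts as initialized
  }
};

void register_my_backend_hooks() {
  static MyPrivateUse1Hooks hooks;  // must outlive every use of the hooks
  at::RegisterPrivateUse1HooksInterface(&hooks);
  // Generic code can now resolve the backend, e.g. the autograd engine's
  // stash_current_privateuse1_streams() shown later in this diff:
  bool ok = at::GetPrivateUse1HooksInterface()->hasPrimaryContext(0);
  (void)ok;
}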

View File

@@ -46,12 +46,6 @@ struct MPSHooks : public at::MPSHooksInterface {
   void synchronizeEvent(uint32_t event_id) const override;
   bool queryEvent(uint32_t event_id) const override;
   double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
-  // Compatibility with Accelerator API
-  bool hasPrimaryContext(DeviceIndex device_index) const override {
-    // When MPS is available, it is always in use for the one device.
-    return true;
-  }
 };
 } // namespace at::mps

View File

@@ -970,7 +970,6 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/AccumulateType.cpp",
     "aten/src/ATen/LegacyBatchedTensorImpl.cpp",
     "aten/src/ATen/CPUGeneratorImpl.cpp",
-    "aten/src/ATen/DeviceAccelerator.cpp",
     "aten/src/ATen/Context.cpp",
     "aten/src/ATen/DLConvertor.cpp",
     "aten/src/ATen/EmptyTensor.cpp",

View File

@@ -8,7 +8,6 @@
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/dynamo/compiled_autograd.h>
-#include <ATen/DeviceAccelerator.h>
 #include <ATen/DeviceGuard.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/Parallel.h>
@@ -47,7 +46,8 @@
 #include <unordered_set>
 #include <utility>
-namespace torch::autograd {
+namespace torch {
+namespace autograd {
 namespace {
 static bool in_bad_autograd_fork =
@@ -991,7 +991,10 @@ void Engine::evaluate_function(
   // ensure they're safe to consume in the context of the present
   // func's stream (if applicable). So we guard onto that stream
   // before working with the grads in any capacity.
-  auto opt_parent_stream = (*func).stream();
+  auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
+  if (!opt_parent_stream.has_value()) {
+    opt_parent_stream = (*func).stream(c10::DeviceType::PrivateUse1);
+  }
   c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
   // If exec_info_ is not empty, we have to instrument the execution
@@ -1009,7 +1012,10 @@ void Engine::evaluate_function(
         *func, InputBuffer::variables(std::move(inputs)));
   }
   if (auto* capture_vec = fn_info.captures_.get()) {
-    auto opt_parent_stream = (*func).stream();
+    auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
+    if (!opt_parent_stream.has_value()) {
+      opt_parent_stream = (*func).stream(c10::DeviceType::PrivateUse1);
+    }
     // Lock mutex for writing to graph_task->captured_vars_.
     std::lock_guard<std::mutex> lock(graph_task->mutex_);
     for (const auto& capture : *capture_vec) {
@@ -1101,7 +1107,10 @@ void Engine::evaluate_function(
       InputBuffer input_buffer(next.function->num_inputs());
       // Accumulates into buffer
-      auto opt_next_stream = next.function->stream();
+      auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
+      if (!opt_next_stream.has_value()) {
+        opt_next_stream = next.function->stream(c10::DeviceType::PrivateUse1);
+      }
       input_buffer.add(
           next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
@@ -1117,7 +1126,10 @@ void Engine::evaluate_function(
       auto& input_buffer = not_ready_it->second;
       // Accumulates into buffer
-      auto opt_next_stream = next.function->stream();
+      auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
+      if (!opt_next_stream.has_value()) {
+        opt_next_stream = next.function->stream(c10::DeviceType::PrivateUse1);
+      }
       input_buffer.add(
           next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
       if (is_ready) {
@@ -1149,7 +1161,10 @@ auto Engine::compute_dependencies(
     uint64_t min_topo_nr) -> void {
   // Computes the number of dependencies for each function which requires grad
   std::vector<Node*> queue{root};
-  bool will_use_accelerator = false;
+  bool might_use_cuda = at::globalContext().hasCUDA();
+  bool might_use_privateuse1 = at::isPrivateUse1HooksRegistered();
+  bool will_use_cuda = false;
+  bool will_use_privateuse1 = false;
   // Queue contains all nodes that will start propagating gradients.
   // We no longer have to expand functions that don't require grad.
@@ -1160,8 +1175,12 @@ auto Engine::compute_dependencies(
     if (fn->topological_nr() < min_topo_nr) {
       continue;
     }
-    if (!will_use_accelerator) {
-      will_use_accelerator = fn->stream().has_value();
+    if (might_use_cuda && !will_use_cuda) {
+      will_use_cuda = fn->stream(c10::DeviceType::CUDA).has_value();
+    }
+    if (might_use_privateuse1 && !will_use_privateuse1) {
+      will_use_privateuse1 =
+          fn->stream(c10::DeviceType::PrivateUse1).has_value();
     }
     for (const auto& edge : fn->next_edges()) {
       if (auto next_ptr = edge.function.get()) {
@@ -1173,11 +1192,21 @@ auto Engine::compute_dependencies(
     }
   }
-  if (will_use_accelerator) {
-    // Collects current streams for devices where this process has a
+  if (will_use_cuda) {
+    // Collects current streams for CUDA/ROCM devices where this process has a
     // context, so GraphTask::exec_post_processing can sync them with
     // leaf_streams.
-    task.stash_current_streams();
+    task.stash_current_cuda_streams();
+  }
+  // Assume that two devices will not be used simultaneously.
+  TORCH_CHECK(
+      !(will_use_cuda && will_use_privateuse1),
+      "CUDA and privateuse1 cannot be used simultaneously in streaming backwards");
+  if (will_use_privateuse1) {
+    // Collects current streams for privateuse1.
+    task.stash_current_privateuse1_streams();
   }
 }
@@ -1267,7 +1296,12 @@ auto Engine::execute(
     auto input = inputs.at(0);
     const auto input_stream = InputMetadata(input).stream();
-    auto opt_next_stream = root_edges.at(0).function->stream();
+    auto opt_next_stream =
+        root_edges.at(0).function->stream(c10::DeviceType::CUDA);
+    if (!opt_next_stream.has_value()) {
+      opt_next_stream =
+          root_edges.at(0).function->stream(c10::DeviceType::PrivateUse1);
+    }
     input_buffer.add(
         root_edges.at(0).input_nr,
         std::move(input),
@@ -1535,15 +1569,15 @@ void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
   thread_pool_shared_->work_.notify_one();
 }
-// Remembers current streams on all devices where a context has been created for
-// This function assumes the accelerator device is available.
-void GraphTask::stash_current_streams() {
-  const auto accelerator = at::getAccelerator(true).value();
-  const auto guard = c10::impl::VirtualGuardImpl{accelerator};
-  auto num_devices = guard.deviceCount();
-  caller_current_streams_.resize(num_devices);
-  if (num_devices > 0) {
-    for (c10::DeviceIndex idx = 0; idx < num_devices; idx++) {
+// Remembers current streams on all CUDA/ROCM devices where a context has been
+// created. Only called if Engine::execute detects at least one node runs on a
+// cuda/rocm stream.
+void GraphTask::stash_current_cuda_streams() {
+  const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
+  auto num_gpus = guard.deviceCount();
+  caller_current_streams_.resize(num_gpus);
+  if (num_gpus > 0) {
+    for (c10::DeviceIndex idx = 0; idx < num_gpus; idx++) {
 #if defined(USE_ROCM) && (ROCM_VERSION < 50000)
       // If the build targets ROCM, stash streams for all visible devices
       // unconditionally, to work around
@@ -1552,10 +1586,28 @@ void GraphTask::stash_current_streams() {
      // https://github.com/pytorch/pytorch/issues/59750 is fixed.
      if (true) {
 #else
-      if (at::globalContext().getAcceleratorHooksInterface().hasPrimaryContext(
-              idx)) {
+      if (at::detail::getCUDAHooks().hasPrimaryContext(idx)) {
 #endif
-        caller_current_streams_[idx] = guard.getStream({accelerator, idx});
+        caller_current_streams_[idx] =
+            guard.getStream({c10::DeviceType::CUDA, idx});
+      } else {
+        caller_current_streams_[idx] = c10::nullopt;
+      }
+    }
+  }
+}
+
+// Remembers current streams on all devices where a context has been created for
+// privateuse1
+void GraphTask::stash_current_privateuse1_streams() {
+  const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::PrivateUse1};
+  auto num_devices = guard.deviceCount();
+  caller_current_streams_.resize(num_devices);
+  if (num_devices > 0) {
+    for (c10::DeviceIndex idx = 0; idx < num_devices; idx++) {
+      if (at::GetPrivateUse1HooksInterface()->hasPrimaryContext(idx)) {
+        caller_current_streams_[idx] =
+            guard.getStream({c10::DeviceType::PrivateUse1, idx});
       } else {
         caller_current_streams_[idx] = c10::nullopt;
       }
@@ -1675,4 +1727,5 @@ void GraphTask::init_to_execute(
   }
 }
-} // namespace torch::autograd
+} // namespace autograd
+} // namespace torch
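
After the revert, the engine re-inlines the same two-step lookup (try CUDA, then fall back to PrivateUse1) at every call site in the hunks above. Purely as a reading aid, here is that pattern distilled into a free-standing sketch; the engine itself does not define such a helper:

#include <c10/core/Stream.h>
#include <c10/util/Optional.h>

// Returns the stream Node::stream(device_type) reports for CUDA, falling back
// to PrivateUse1, mirroring the reverted call sites above.
template <typename NodeT>
c10::optional<c10::Stream> stream_cuda_then_privateuse1(NodeT& fn) {
  auto stream = fn.stream(c10::DeviceType::CUDA);
  if (!stream.has_value()) {
    stream = fn.stream(c10::DeviceType::PrivateUse1);
  }
  return stream;
}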

View File

@@ -239,13 +239,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
    * elements are on different devices (across multiple GPUs, for example)
    * they may have different streams.
    */
-  c10::optional<c10::Stream> stream() {
-    auto opt_device_type = at::getAccelerator();
-    if (!opt_device_type.has_value()) {
-      return c10::nullopt;
-    }
+  c10::optional<c10::Stream> stream(const c10::DeviceType device_type) {
     for (const auto& metadata : input_metadata_) {
-      if (metadata.device().type() == opt_device_type.value())
+      if (metadata.device().type() == device_type)
         return metadata.stream();
     }

View File

@@ -127,8 +127,11 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
   // These will be synced with leaf_streams in exec_post_processing.
   std::vector<c10::optional<c10::Stream>> caller_current_streams_;
-  // Collects caller_current_streams_ for the accelerator device.
-  void stash_current_streams();
+  // Collects caller_current_streams_ for cuda/rocm.
+  void stash_current_cuda_streams();
+  // Collects caller_current_streams_ for privateuse1.
+  void stash_current_privateuse1_streams();
   void init_to_execute(
       Node& graph_root,

View File

@@ -90,7 +90,8 @@ void DistAutogradContext::accumulateGrad(
     // CUDA stream restoration from autograd function. Hence, we manually
     // call it here to get the streams correct.
     auto forward_stream =
-        torch::autograd::impl::grad_accumulator(variable)->stream();
+        torch::autograd::impl::grad_accumulator(variable)->stream(
+            grad.device().type());
     c10::OptionalStreamGuard stream_guard(forward_stream);
     // No higher order gradients supported in distributed autograd.

View File

@@ -219,7 +219,8 @@ void DistEngine::computeDependencies(
     queue.push(mapEntry.second.get());
   }
-  bool will_use_accelerator = false;
+  bool might_use_cuda = at::globalContext().hasCUDA();
+  bool will_use_cuda = false;
   edge_list recvBackwardEdges;
   // Traverse the graph.
@@ -228,8 +229,8 @@ void DistEngine::computeDependencies(
     auto fn = queue.front();
     queue.pop();
-    if (!will_use_accelerator) {
-      will_use_accelerator = fn->stream().has_value();
+    if (might_use_cuda && !will_use_cuda) {
+      will_use_cuda = fn->stream(c10::DeviceType::CUDA).has_value();
     }
     for (const auto& edge : fn->next_edges()) {
@@ -268,11 +269,11 @@ void DistEngine::computeDependencies(
     }
   }
-  if (will_use_accelerator) {
+  if (will_use_cuda) {
     // Collects current streams for CUDA/ROCM devices where this process has a
     // context, so graphTask::exec_post_processing can sync them with
     // leaf_streams.
-    graphTask->stash_current_streams();
+    graphTask->stash_current_cuda_streams();
   }
   // Now lets compute which functions need to be executed. The algorithm is as
@@ -460,7 +461,8 @@ c10::intrusive_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
   // inputs might have been retrieved over the wire on a separate stream and the
   // sendFunction itself runs on a different stream. As a result, we need to
   // manually synchronize those two streams here.
-  const auto& send_backward_stream = sendFunction->stream();
+  const auto& send_backward_stream =
+      sendFunction->stream(c10::DeviceType::CUDA);
   if (send_backward_stream) {
     for (const auto& grad : sendFunction->getGrads()) {
       const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};