Revert "Add Accelerator device and shell hooks (#119329)"
This reverts commit 4b9568a360c4a90220e78e43435be8c56bc33fb2. Reverted https://github.com/pytorch/pytorch/pull/119329 on behalf of https://github.com/huydhn due to Breaks internal build and requires OSS file update to fix it ([comment](https://github.com/pytorch/pytorch/pull/119329#issuecomment-1940278598))
@@ -1,13 +1,11 @@
 #pragma once
 
 #include <ATen/CPUGeneratorImpl.h>
-#include <ATen/DeviceAccelerator.h>
 #include <ATen/LinalgBackend.h>
 #include <ATen/core/ATenGeneral.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
 #include <ATen/core/Generator.h>
 #include <ATen/core/LegacyTypeDispatch.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/detail/HIPHooksInterface.h>
 #include <ATen/detail/IPUHooksInterface.h>
@@ -58,22 +56,6 @@ class TORCH_API Context {
       AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
     }
   }
-  const AcceleratorHooksInterface& getAcceleratorHooksInterface(
-      c10::optional<c10::DeviceType> opt_device_type = c10::nullopt) {
-    c10::DeviceType device_type = opt_device_type.has_value()
-        ? opt_device_type.value()
-        : at::getAccelerator(true).value();
-    if (device_type == at::kCUDA) {
-      return at::detail::getCUDAHooks();
-    } else if (device_type == at::kMPS) {
-      return at::detail::getMPSHooks();
-    } else if (device_type == at::kPrivateUse1) {
-      return at::detail::getPrivateUse1Hooks();
-    } else {
-      AT_ERROR(
-          c10::DeviceTypeName(device_type), " device type not an accelerator.");
-    }
-  }
   Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
     initCUDAIfNeeded(device_type);
     initHIPIfNeeded(device_type);
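For context, a minimal sketch (not part of this commit; the helper name is made up) of how generic code could use the getAcceleratorHooksInterface accessor removed above, in its pre-revert form:

#include <ATen/Context.h>

// Hypothetical helper: ask the active accelerator's hooks whether a device
// has been fully initialized. With no argument, the accessor resolves the
// device type via at::getAccelerator(true), so an accelerator must exist.
bool accelerator_device_ready(c10::DeviceIndex idx) {
  const at::AcceleratorHooksInterface& hooks =
      at::globalContext().getAcceleratorHooksInterface();
  return hooks.hasPrimaryContext(idx);
}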
@@ -1,31 +0,0 @@
-#include <ATen/DeviceAccelerator.h>
-#include <ATen/Context.h>
-
-namespace at {
-
-C10_API std::optional<DeviceType> getAccelerator(bool checked) {
-#define CHECK_NO_CUDA \
-  TORCH_CHECK(!at::hasCUDA(), "Cannot have both CUDA and PrivateUse1");
-
-#define CHECK_NO_PU1 \
-  TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1");
-
-  if (is_privateuse1_backend_registered()) {
-    // We explicitly allow PrivateUse1 and another device at the same time
-    // as we use this for testing.
-    // Whenever a PrivateUse1 device is registered, use it first.
-    return kPrivateUse1;
-  } else if (at::hasCUDA()) {
-    CHECK_NO_PU1
-    return kCUDA;
-  } else {
-    TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.")
-    return std::nullopt;
-  }
-
-#undef CHECK_NO_CUDA
-#undef CHECK_NO_PU1
-}
-
-
-} // namespace at
@@ -1,27 +0,0 @@
-#pragma once
-
-#include <c10/core/DeviceType.h>
-#include <c10/macros/Macros.h>
-
-#include <ATen/detail/MTIAHooksInterface.h>
-#include <optional>
-
-// This file defines the top level Accelerator concept for PyTorch.
-// A device is an accelerator per the definition here if:
-// - It is mutually exclusive with all other accelerators
-// - It performs asynchronous compute via a Stream/Event system
-// - It provides a set of common APIs as defined by AcceleratorHooksInterface
-//
-// As of today, accelerator devices are (in no particular order):
-// CUDA, MTIA, PrivateUse1
-// We want to add once all the proper APIs are supported and tested:
-// HIP, MPS, XPU
-
-namespace at {
-
-// Ensures that only one accelerator is available (at
-// compile time if possible) and return it.
-// When checked is true, the returned optional always has a value.
-TORCH_API std::optional<DeviceType> getAccelerator(bool checked = false);
-
-} // namespace at
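A minimal usage sketch of the at::getAccelerator helper documented above (illustrative only, pre-revert API; the function name is made up):

#include <ATen/DeviceAccelerator.h>
#include <iostream>

// With checked=false the call returns std::nullopt instead of throwing
// when no accelerator backend is present.
void print_accelerator() {
  if (auto device_type = at::getAccelerator(/*checked=*/false)) {
    std::cout << "accelerator: " << c10::DeviceTypeName(*device_type) << std::endl;
  } else {
    std::cout << "no accelerator available" << std::endl;
  }
}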
@@ -1,21 +0,0 @@
-#pragma once
-
-#include <c10/core/Device.h>
-
-namespace at {
-
-// AcceleratorHooksInterface is a shared interface provided by all
-// accelerators to allow generic code.
-// This inferface is hook-based as it corresponds to all the functions
-// that are going to be called in a generic way from the CPU code.
-
-struct TORCH_API AcceleratorHooksInterface {
-  // This should never actually be implemented, but it is used to
-  // squelch -Werror=non-virtual-dtor
-  virtual ~AcceleratorHooksInterface() = default;
-
-  // Whether the device at device_index is fully initialized or not.
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
-};
-
-} // namespace at
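For illustration only (not part of the patch), a sketch of how an out-of-tree backend would have satisfied the interface deleted above; FakeHooks is a made-up name:

#include <ATen/detail/AcceleratorHooksInterface.h>

// Hypothetical backend hooks: the only pure-virtual member is
// hasPrimaryContext, which reports whether a device is fully initialized.
struct FakeHooks final : at::AcceleratorHooksInterface {
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    // Toy policy: pretend only device 0 ever has a context.
    return device_index == 0;
  }
};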
@@ -4,8 +4,6 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Registry.h>
 
-#include <ATen/detail/AcceleratorHooksInterface.h>
-
 // Forward-declares at::Generator and at::cuda::NVRTC
 namespace at {
 struct Generator;
@@ -59,7 +57,7 @@ constexpr const char* CUDA_HELP =
 // TODO: Consider putting the stub definitions in another class, so that one
 // never forgets to implement each virtual function in the real implementation
 // in CUDAHooks. This probably doesn't buy us much though.
-struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
+struct TORCH_API CUDAHooksInterface {
   // This should never actually be implemented, but it is used to
   // squelch -Werror=non-virtual-dtor
   virtual ~CUDAHooksInterface() = default;
@@ -109,7 +107,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
     TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
   }
 
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const {
     TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
   }
 
@@ -4,7 +4,6 @@
 
 #include <c10/core/Allocator.h>
 #include <ATen/core/Generator.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Registry.h>
 
@@ -12,7 +11,7 @@
 
 namespace at {
 
-struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
+struct TORCH_API MPSHooksInterface {
 // this fails the implementation if MPSHooks functions are called, but
 // MPS backend is not present.
 #define FAIL_MPSHOOKS_FUNC(func) \
@@ -87,9 +86,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
   virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
     FAIL_MPSHOOKS_FUNC(__func__);
   }
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
-    FAIL_MPSHOOKS_FUNC(__func__);
-  }
 
 #undef FAIL_MPSHOOKS_FUNC
 };
 
@@ -4,8 +4,6 @@
 
 #include <c10/util/Registry.h>
 
-#include <ATen/detail/AcceleratorHooksInterface.h>
-
 #include <string>
 
 namespace at {
@@ -19,7 +17,7 @@ constexpr const char* MTIA_HELP =
     "this error has occurred because you are trying "
     "to use some MTIA's functionality without MTIA extension included.";
 
-struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
+struct TORCH_API MTIAHooksInterface {
   virtual ~MTIAHooksInterface() = default;
 
   virtual void initMTIA() const {
@@ -39,14 +37,6 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
         "Cannot query detailed MTIA version without MTIA Extension for PyTorch.",
         MTIA_HELP);
   }
-
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
-    TORCH_CHECK(
-        false,
-        "Cannot check MTIA primary context without MTIA Extension for PyTorch.",
-        MTIA_HELP);
-  }
-
 };
 
 struct TORCH_API MTIAHooksArgs {};
@@ -22,15 +22,4 @@ TORCH_API bool isPrivateUse1HooksRegistered() {
   return privateuse1_hooks != nullptr;
 }
 
-namespace detail {
-
-TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks() {
-  TORCH_CHECK(
-      privateuse1_hooks != nullptr,
-      "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first.");
-  return *privateuse1_hooks;
-}
-
-} // namespace detail
-
 } // namespace at
@@ -1,14 +1,13 @@
 #pragma once
 
 #include <ATen/core/Generator.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <c10/core/Allocator.h>
 #include <c10/core/Device.h>
 #include <c10/core/Storage.h>
 #include <c10/util/Exception.h>
 namespace at {
 
-struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
+struct TORCH_API PrivateUse1HooksInterface {
   virtual ~PrivateUse1HooksInterface() = default;
   virtual const at::Generator& getDefaultGenerator(
       c10::DeviceIndex device_index) {
@@ -29,7 +28,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
         "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
   }
 
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const {
     TORCH_CHECK_NOT_IMPLEMENTED(
         false,
         "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
@@ -52,10 +51,4 @@ TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface();
 
 TORCH_API bool isPrivateUse1HooksRegistered();
 
-namespace detail {
-
-TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
-
-} // namespace detail
-
 } // namespace at
@@ -46,12 +46,6 @@ struct MPSHooks : public at::MPSHooksInterface {
   void synchronizeEvent(uint32_t event_id) const override;
   bool queryEvent(uint32_t event_id) const override;
   double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
-
-  // Compatibility with Accelerator API
-  bool hasPrimaryContext(DeviceIndex device_index) const override {
-    // When MPS is available, it is always in use for the one device.
-    return true;
-  }
 };
 
 } // namespace at::mps
@@ -970,7 +970,6 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/AccumulateType.cpp",
     "aten/src/ATen/LegacyBatchedTensorImpl.cpp",
     "aten/src/ATen/CPUGeneratorImpl.cpp",
-    "aten/src/ATen/DeviceAccelerator.cpp",
     "aten/src/ATen/Context.cpp",
     "aten/src/ATen/DLConvertor.cpp",
     "aten/src/ATen/EmptyTensor.cpp",
@@ -8,7 +8,6 @@
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/dynamo/compiled_autograd.h>
 
-#include <ATen/DeviceAccelerator.h>
 #include <ATen/DeviceGuard.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/Parallel.h>
@@ -47,7 +46,8 @@
 #include <unordered_set>
 #include <utility>
 
-namespace torch::autograd {
+namespace torch {
+namespace autograd {
 
 namespace {
 static bool in_bad_autograd_fork =
@@ -991,7 +991,10 @@ void Engine::evaluate_function(
   // ensure they're safe to consume in the context of the present
   // func's stream (if applicable). So we guard onto that stream
   // before working with the grads in any capacity.
-  auto opt_parent_stream = (*func).stream();
+  auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
+  if (!opt_parent_stream.has_value()) {
+    opt_parent_stream = (*func).stream(c10::DeviceType::PrivateUse1);
+  }
   c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
 
   // If exec_info_ is not empty, we have to instrument the execution
@@ -1009,7 +1012,10 @@ void Engine::evaluate_function(
         *func, InputBuffer::variables(std::move(inputs)));
   }
   if (auto* capture_vec = fn_info.captures_.get()) {
-    auto opt_parent_stream = (*func).stream();
+    auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
+    if (!opt_parent_stream.has_value()) {
+      opt_parent_stream = (*func).stream(c10::DeviceType::PrivateUse1);
+    }
     // Lock mutex for writing to graph_task->captured_vars_.
     std::lock_guard<std::mutex> lock(graph_task->mutex_);
     for (const auto& capture : *capture_vec) {
@@ -1101,7 +1107,10 @@ void Engine::evaluate_function(
       InputBuffer input_buffer(next.function->num_inputs());
 
       // Accumulates into buffer
-      auto opt_next_stream = next.function->stream();
+      auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
+      if (!opt_next_stream.has_value()) {
+        opt_next_stream = next.function->stream(c10::DeviceType::PrivateUse1);
+      }
       input_buffer.add(
           next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
 
@@ -1117,7 +1126,10 @@ void Engine::evaluate_function(
       auto& input_buffer = not_ready_it->second;
 
       // Accumulates into buffer
-      auto opt_next_stream = next.function->stream();
+      auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
+      if (!opt_next_stream.has_value()) {
+        opt_next_stream = next.function->stream(c10::DeviceType::PrivateUse1);
+      }
       input_buffer.add(
           next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
       if (is_ready) {
@@ -1149,7 +1161,10 @@ auto Engine::compute_dependencies(
     uint64_t min_topo_nr) -> void {
   // Computes the number of dependencies for each function which requires grad
   std::vector<Node*> queue{root};
-  bool will_use_accelerator = false;
+  bool might_use_cuda = at::globalContext().hasCUDA();
+  bool might_use_privateuse1 = at::isPrivateUse1HooksRegistered();
+  bool will_use_cuda = false;
+  bool will_use_privateuse1 = false;
 
   // Queue contains all nodes that will start propagating gradients.
   // We no longer have to expand functions that don't require grad.
@@ -1160,8 +1175,12 @@ auto Engine::compute_dependencies(
     if (fn->topological_nr() < min_topo_nr) {
       continue;
     }
-    if (!will_use_accelerator) {
-      will_use_accelerator = fn->stream().has_value();
+    if (might_use_cuda && !will_use_cuda) {
+      will_use_cuda = fn->stream(c10::DeviceType::CUDA).has_value();
     }
+    if (might_use_privateuse1 && !will_use_privateuse1) {
+      will_use_privateuse1 =
+          fn->stream(c10::DeviceType::PrivateUse1).has_value();
+    }
     for (const auto& edge : fn->next_edges()) {
       if (auto next_ptr = edge.function.get()) {
@@ -1173,11 +1192,21 @@ auto Engine::compute_dependencies(
     }
   }
 
-  if (will_use_accelerator) {
-    // Collects current streams for devices where this process has a
+  if (will_use_cuda) {
+    // Collects current streams for CUDA/ROCM devices where this process has a
     // context, so GraphTask::exec_post_processing can sync them with
     // leaf_streams.
-    task.stash_current_streams();
+    task.stash_current_cuda_streams();
   }
+
+  // Assume that two devices will not be used simultaneously.
+  TORCH_CHECK(
+      !(will_use_cuda && will_use_privateuse1),
+      "CUDA and privateuse1 cannot be used simultaneously in streaming backwards");
+
+  if (will_use_privateuse1) {
+    // Collects current streams for privateuse1.
+    task.stash_current_privateuse1_streams();
+  }
 }
 
@@ -1267,7 +1296,12 @@ auto Engine::execute(
     auto input = inputs.at(0);
 
     const auto input_stream = InputMetadata(input).stream();
-    auto opt_next_stream = root_edges.at(0).function->stream();
+    auto opt_next_stream =
+        root_edges.at(0).function->stream(c10::DeviceType::CUDA);
+    if (!opt_next_stream.has_value()) {
+      opt_next_stream =
+          root_edges.at(0).function->stream(c10::DeviceType::PrivateUse1);
+    }
     input_buffer.add(
         root_edges.at(0).input_nr,
         std::move(input),
@@ -1535,15 +1569,15 @@ void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
   thread_pool_shared_->work_.notify_one();
 }
 
-// Remembers current streams on all devices where a context has been created for
-// This function assumes the accelerator device is available.
-void GraphTask::stash_current_streams() {
-  const auto accelerator = at::getAccelerator(true).value();
-  const auto guard = c10::impl::VirtualGuardImpl{accelerator};
-  auto num_devices = guard.deviceCount();
-  caller_current_streams_.resize(num_devices);
-  if (num_devices > 0) {
-    for (c10::DeviceIndex idx = 0; idx < num_devices; idx++) {
+// Remembers current streams on all CUDA/ROCM devices where a context has been
+// created. Only called if Engine::execute detects at least one node runs on a
+// cuda/rocm stream.
+void GraphTask::stash_current_cuda_streams() {
+  const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
+  auto num_gpus = guard.deviceCount();
+  caller_current_streams_.resize(num_gpus);
+  if (num_gpus > 0) {
+    for (c10::DeviceIndex idx = 0; idx < num_gpus; idx++) {
 #if defined(USE_ROCM) && (ROCM_VERSION < 50000)
       // If the build targets ROCM, stash streams for all visible devices
       // unconditionally, to work around
@@ -1552,10 +1586,28 @@ void GraphTask::stash_current_streams() {
       // https://github.com/pytorch/pytorch/issues/59750 is fixed.
       if (true) {
 #else
-      if (at::globalContext().getAcceleratorHooksInterface().hasPrimaryContext(
-              idx)) {
+      if (at::detail::getCUDAHooks().hasPrimaryContext(idx)) {
 #endif
-        caller_current_streams_[idx] = guard.getStream({accelerator, idx});
+        caller_current_streams_[idx] =
+            guard.getStream({c10::DeviceType::CUDA, idx});
       } else {
         caller_current_streams_[idx] = c10::nullopt;
       }
+    }
+  }
+}
+
+// Remembers current streams on all devices where a context has been created for
+// privateuse1
+void GraphTask::stash_current_privateuse1_streams() {
+  const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::PrivateUse1};
+  auto num_devices = guard.deviceCount();
+  caller_current_streams_.resize(num_devices);
+  if (num_devices > 0) {
+    for (c10::DeviceIndex idx = 0; idx < num_devices; idx++) {
+      if (at::GetPrivateUse1HooksInterface()->hasPrimaryContext(idx)) {
+        caller_current_streams_[idx] =
+            guard.getStream({c10::DeviceType::PrivateUse1, idx});
+      } else {
+        caller_current_streams_[idx] = c10::nullopt;
+      }
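As an aside (illustrative, not part of the patch): the device-generic stashing loop removed above reduces to roughly this pattern, shown here without the hasPrimaryContext/ROCm guards; it assumes some accelerator is registered:

#include <ATen/DeviceAccelerator.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Optional.h>
#include <vector>

// Stash the current (thread-local) stream of every device of the active
// accelerator type, as the reverted GraphTask::stash_current_streams did.
std::vector<c10::optional<c10::Stream>> stash_current_streams_sketch() {
  const auto accelerator = at::getAccelerator(/*checked=*/true).value();
  const c10::impl::VirtualGuardImpl guard{accelerator};
  std::vector<c10::optional<c10::Stream>> streams(guard.deviceCount());
  for (c10::DeviceIndex idx = 0; idx < guard.deviceCount(); idx++) {
    streams[idx] = guard.getStream({accelerator, idx});
  }
  return streams;
}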
@@ -1675,4 +1727,5 @@ void GraphTask::init_to_execute(
   }
 }
 
-} // namespace torch::autograd
+} // namespace autograd
+} // namespace torch
@@ -239,13 +239,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
    * elements are on different devices (across multiple GPUs, for example)
    * they may have different streams.
    */
-  c10::optional<c10::Stream> stream() {
-    auto opt_device_type = at::getAccelerator();
-    if (!opt_device_type.has_value()) {
-      return c10::nullopt;
-    }
+  c10::optional<c10::Stream> stream(const c10::DeviceType device_type) {
     for (const auto& metadata : input_metadata_) {
-      if (metadata.device().type() == opt_device_type.value())
+      if (metadata.device().type() == device_type)
         return metadata.stream();
     }
 
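Illustrative only: with the restored per-device-type signature, callers that want whichever accelerator is in use have to probe the device types explicitly, mirroring the restored call sites in engine.cpp (the helper name is made up):

#include <c10/core/Stream.h>
#include <torch/csrc/autograd/function.h>

// Probe CUDA first, then PrivateUse1; with the reverted device-agnostic
// Node::stream() a single call covered both cases.
c10::optional<c10::Stream> first_accelerator_stream(torch::autograd::Node& node) {
  auto stream = node.stream(c10::DeviceType::CUDA);
  if (!stream.has_value()) {
    stream = node.stream(c10::DeviceType::PrivateUse1);
  }
  return stream;
}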
@@ -127,8 +127,11 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
   // These will be synced with leaf_streams in exec_post_processing.
   std::vector<c10::optional<c10::Stream>> caller_current_streams_;
 
-  // Collects caller_current_streams_ for the accelerator device.
-  void stash_current_streams();
+  // Collects caller_current_streams_ for cuda/rocm.
+  void stash_current_cuda_streams();
+
+  // Collects caller_current_streams_ for privateuse1.
+  void stash_current_privateuse1_streams();
 
   void init_to_execute(
       Node& graph_root,
@@ -90,7 +90,8 @@ void DistAutogradContext::accumulateGrad(
   // CUDA stream restoration from autograd function. Hence, we manually
   // call it here to get the streams correct.
   auto forward_stream =
-      torch::autograd::impl::grad_accumulator(variable)->stream();
+      torch::autograd::impl::grad_accumulator(variable)->stream(
+          grad.device().type());
   c10::OptionalStreamGuard stream_guard(forward_stream);
 
   // No higher order gradients supported in distributed autograd.
@@ -219,7 +219,8 @@ void DistEngine::computeDependencies(
     queue.push(mapEntry.second.get());
   }
 
-  bool will_use_accelerator = false;
+  bool might_use_cuda = at::globalContext().hasCUDA();
+  bool will_use_cuda = false;
 
   edge_list recvBackwardEdges;
   // Traverse the graph.
@@ -228,8 +229,8 @@ void DistEngine::computeDependencies(
     auto fn = queue.front();
    queue.pop();
 
-    if (!will_use_accelerator) {
-      will_use_accelerator = fn->stream().has_value();
+    if (might_use_cuda && !will_use_cuda) {
+      will_use_cuda = fn->stream(c10::DeviceType::CUDA).has_value();
     }
 
     for (const auto& edge : fn->next_edges()) {
@@ -268,11 +269,11 @@ void DistEngine::computeDependencies(
     }
   }
 
-  if (will_use_accelerator) {
+  if (will_use_cuda) {
     // Collects current streams for CUDA/ROCM devices where this process has a
     // context, so graphTask::exec_post_processing can sync them with
     // leaf_streams.
-    graphTask->stash_current_streams();
+    graphTask->stash_current_cuda_streams();
   }
 
   // Now lets compute which functions need to be executed. The algorithm is as
@@ -460,7 +461,8 @@ c10::intrusive_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
   // inputs might have been retrieved over the wire on a separate stream and the
   // sendFunction itself runs on a different stream. As a result, we need to
   // manually synchronize those two streams here.
-  const auto& send_backward_stream = sendFunction->stream();
+  const auto& send_backward_stream =
+      sendFunction->stream(c10::DeviceType::CUDA);
   if (send_backward_stream) {
     for (const auto& grad : sendFunction->getGrads()) {
       const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};