pytorch/torch/csrc/distributed/autograd/engine/dist_engine.cpp
#include <queue>
#include <ATen/Parallel.h>
#include <c10/core/Event.h>
#include <c10/util/DeadlockDetection.h>
#include <c10/util/irange.h>
#include <c10/util/thread_name.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
namespace torch::distributed::autograd {
using torch::autograd::AccumulateGrad;
using torch::autograd::edge_list;
using torch::autograd::Engine;
using torch::autograd::GraphRoot;
using torch::autograd::GraphTask;
using torch::autograd::GraphTaskGuard;
using torch::autograd::InputBuffer;
using torch::autograd::Node;
using torch::autograd::NodeTask;
using torch::autograd::ReadyQueue;
using torch::autograd::validate_outputs;
using torch::autograd::variable_list;
static constexpr const char* kNumBackwardPasses = "num_current_backward_passes";
static constexpr const char* kNumAutogradContexts = "num_autograd_contexts";
// This hook does 3 things:
// 1. Call pre hooks of the original AccumulateGrad to modify the input grad.
// 2. Accumulate the grad into the RPC context.
// 3. Call post hooks of the original AccumulateGrad.
class DistAccumulateGradCaptureHook
: public GraphTask::ExecInfo::Capture::GradCaptureHook {
public:
DistAccumulateGradCaptureHook(
std::shared_ptr<AccumulateGrad> accumulateGrad,
ContextPtr autogradContext)
: accumulateGrad_(std::move(accumulateGrad)),
autogradContext_(std::move(autogradContext)) {}
at::Tensor operator()(const at::Tensor& grad) override {
ThreadLocalDistAutogradContext contextGuard{ContextPtr(autogradContext_)};
variable_list inputGrads = {grad};
// It's intended that pre/post hooks are still called even if the grad is
// undefined here.
for (const auto& hook : accumulateGrad_->pre_hooks()) {
inputGrads = (*hook)(inputGrads);
}
// It is possible that the grad is not defined since a separate
// invocation of the autograd engine on the same node might actually
// compute this gradient.
if (inputGrads[0].defined()) {
// There are 3 internal references to 'inputGrads[0]' at this moment:
// 1. 'inputGrads[0]' in this function.
// 2. 'graph_task->captured_vars_' on the callsite in the local engine.
// 3. 'InputBuffer& inputs' on the callsite as the inputs of the
// function node.
autogradContext_->accumulateGrad(
accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);
}
const variable_list kEmptyOutput;
for (const auto& hook : accumulateGrad_->post_hooks()) {
(*hook)(kEmptyOutput, inputGrads);
}
return std::move(inputGrads[0]);
}
private:
std::shared_ptr<AccumulateGrad> accumulateGrad_;
ContextPtr autogradContext_;
};
void DistEngine::globalCpuThread(
const std::shared_ptr<ReadyQueue>& ready_queue) {
c10::setThreadName("pt_dist_engine");
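// Process tasks from the global CPU ready queue until a shutdown task arrives.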
while (true) {
NodeTask task = ready_queue->pop();
if (task.isShutdownTask_) {
// Need to shut down this thread.
C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
break;
}
auto graphTask = task.base_.lock();
if (graphTask == nullptr) {
// GraphTask has expired, ignore and continue processing.
continue;
}
// Launch the execution on a JIT thread.
at::launch([this,
graphTask,
graphRoot = task.fn_,
variables =
InputBuffer::variables(std::move(task.inputs_))]() mutable {
InputBuffer inputs(variables.size());
for (const auto i : c10::irange(variables.size())) {
inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt);
}
execute_graph_task_until_ready_queue_empty(
/*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
/*incrementOutstandingTasks*/ false);
});
}
}
DistEngine::DistEngine()
: engine_(Engine::get_default_engine()),
global_cpu_ready_queue_(std::make_shared<ReadyQueue>()),
global_cpu_thread_(
&DistEngine::globalCpuThread,
this,
global_cpu_ready_queue_) {
// Note [GPU to CPU continuations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Initialize a single CPU thread to execute continuations from GPU
// tasks. The multithreaded structure for the distributed engine works
// well only for CPU tasks. If we have an order of tasks like
// CPU->GPU->CPU, distributed autograd has no thread to execute the last
// CPU task on. To fix this, we introduce a global CPU thread to handle
// such situations and it will be responsible for executing these CPU
// tasks. The CPU thread has its own ready_queue which is used as the
// cpu_ready_queue for all GraphTasks for DistEngine. This ensures all GPU
// to CPU continuations are enqueued on this thread. The global CPU thread
// simply dequeues tasks from the global queue and calls
// "execute_graph_task_until_ready_queue_empty" on a JIT thread to execute the
// appropriate task.
global_cpu_thread_.detach();
}
DistEngine::~DistEngine() {
// Ensure we shutdown the CPU thread.
TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP();
global_cpu_ready_queue_->pushShutdownTask();
global_cpu_thread_.join();
}
DistEngine& DistEngine::getInstance() {
// Leaky singleton to avoid module destructor race.
static DistEngine* engine = new DistEngine();
return *engine;
}
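// Validates that every root is a scalar that requires grad and has a valid
// grad_fn, then fills 'rootEdges' with each root's gradient edge and 'grads'
// with the corresponding ones_like seed gradients.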
void DistEngine::validateRootsAndRetrieveEdges(
const variable_list& roots,
edge_list& rootEdges,
variable_list& grads) {
TORCH_CHECK(!roots.empty(), "No tensors provided for gradient computation.");
TORCH_INTERNAL_ASSERT(rootEdges.empty());
TORCH_INTERNAL_ASSERT(grads.empty());
// Verify roots are all scalar and require gradients.
for (const auto& root : roots) {
TORCH_CHECK(root.requires_grad(), "requires_grad not set on root");
TORCH_CHECK(
root.numel() == 1,
root.name(),
" is not a scalar, all roots need to be scalar");
TORCH_CHECK(
root.grad_fn(),
root.name(),
" does not have a valid gradient function.");
// Compute the root edges and generate the appropriate gradients.
rootEdges.push_back(torch::autograd::impl::gradient_edge(root));
grads.push_back(at::ones_like(root, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
// Validate rootEdges and grads.
validate_outputs(
rootEdges, grads, [](const std::string& msg) { return msg; });
}
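// Builds the GraphTask for this backward pass: runs a BFS over the local
// autograd graph starting from the graph root and all 'send' functions,
// records dependency counts, collects the leaf edges (AccumulateGrad or
// RecvRpcBackward) as output edges, installs the gradient capture hooks and
// hands the GraphTask over to the autograd context.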
void DistEngine::computeDependencies(
const ContextPtr& autogradContext,
const edge_list& rootEdges,
const variable_list& grads,
const std::shared_ptr<Node>& graphRoot,
edge_list& outputEdges,
bool retainGraph) {
TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!");
// Store root nodes so we can traverse the graph later,
// e.g., for get_current_graph_task_execution_order.
c10::SmallVector<Node*, 4> temp_roots{rootEdges.size()};
for (const auto i : c10::irange(rootEdges.size())) {
temp_roots[i] = rootEdges[i].function.get();
}
// Build the graph task and graph root.
// NOTE: we don't need to build and pass a cpu_ready_queue to GraphTask
// as we use execute_graph_task_until_ready_queue_empty, which will build
// a separate ReadyQueue for each call.
auto graphTask = std::make_shared<GraphTask>(
/* keep_graph */ retainGraph,
/* create_graph */ false,
/* depth */ 0,
/* cpu_ready_queue */ global_cpu_ready_queue_,
/* graph_roots */ temp_roots,
/* exit_on_error */ true);
// Run BFS to traverse the graph locally. The roots of the graph are
// GraphRoot and all send functions for this autograd context.
std::unordered_set<Node*> seen;
std::queue<Node*> queue;
queue.push(static_cast<Node*>(graphRoot.get()));
auto sendFunctions = autogradContext->sendFunctions();
// Add all the send functions to the queue as roots.
for (const auto& mapEntry : sendFunctions) {
// Increment 'outstanding_tasks_' for GraphTask for each send_function
// since we want the local autograd engine to wait for all of them.
graphTask->outstanding_tasks_++;
queue.push(mapEntry.second.get());
}
bool will_use_accelerator = false;
edge_list recvBackwardEdges;
// Traverse the graph.
auto& dependencies = graphTask->dependencies_;
while (!queue.empty()) {
auto fn = queue.front();
queue.pop();
if (!will_use_accelerator) {
will_use_accelerator = fn->stream().has_value();
}
for (const auto& edge : fn->next_edges()) {
if (auto nextFn = edge.function.get()) {
dependencies[nextFn] += 1;
const bool wasInserted = seen.insert(nextFn).second;
if (wasInserted) {
// Seeing this function for the first time.
queue.push(nextFn);
if (nextFn->next_edges().empty()) {
TORCH_INTERNAL_ASSERT(
dynamic_cast<AccumulateGrad*>(nextFn) ||
dynamic_cast<RecvRpcBackward*>(nextFn));
// We have found a leaf node which should be either AccumulateGrad
// or RecvRpcBackward. Record the function
// to ensure we don't execute it and instead accumulate the grads on
// the autograd context. These functions would be passed in as the
// 'outputs' parameter of the vanilla autograd engine.
// We don't accumulate any grads in the context for RecvRpcBackward.
// RecvRpcBackward is added as an output edge to indicate it is a
// leaf node and this helps in properly computing dependencies for
// the local autograd graph. Putting RecvRpcBackward in
// 'outputEdges' means that this function needs to be executed
// (in line with our assumption for FAST mode that all send/recv
// functions are valid in the backward pass), and as a result all of
// its ancestors need to be executed as well.
if (dynamic_cast<RecvRpcBackward*>(nextFn)) {
recvBackwardEdges.emplace_back(edge);
}
outputEdges.emplace_back(edge);
}
}
}
}
}
if (will_use_accelerator) {
// Collects current streams for CUDA/ROCM devices where this process has a
// context, so graphTask::exec_post_processing can sync them with
// leaf_streams.
graphTask->stash_current_streams();
}
// Now lets compute which functions need to be executed. The algorithm is as
// follows:
// 1. Create a dummy GraphRoot which points to all 'send' functions for this
// context and the original graphRoot. Run 'init_to_execute' with the
// outputEdges and the dummy GraphRoot. This ensures we mark
// appropriate functions as needed if they are reachable only from a
// specific 'send' function locally and not necessarily from the provided
// roots.
// 2. For all edges in 'outputEdges' which point to 'RecvRpcBackward', mark
// those functions as needed for execution. The reason for this is that
// 'init_to_execute' will mark these as not needed. But 'RecvRpcBackward'
// is unique in the sense that we use it as a leaf node in the graph to compute
// needed execution accurately, but unlike AccumulateGrad, we do need to
// execute this function.
if (!outputEdges.empty()) {
// Compute 'needed execution' starting from all 'send' functions and the
// original graphRoot.
edge_list edges;
// Create some dummy edges (input_nr not important for init_to_execute).
for (const auto& mapEntry : sendFunctions) {
edges.emplace_back(mapEntry.second, 0);
}
// Add the original graphRoot as an edge.
edges.emplace_back(graphRoot, 0);
// Create a dummy GraphRoot and run init_to_execute with it.
GraphRoot dummyRoot(edges, {});
graphTask->init_to_execute(
dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0);
for (auto& mapEntry : graphTask->exec_info_) {
auto& execInfo = mapEntry.second;
if (!execInfo.captures_) {
continue;
}
auto fn = mapEntry.first;
// There may be nodes other than 'AccumulateGrad', e.g. RecvRpcBackward,
// that need to be captured.
if (auto accumulateGradFn = dynamic_cast<AccumulateGrad*>(fn)) {
for (auto& capture : *execInfo.captures_) {
// Capture hooks are technically deprecated, but the usage below is the one
// and only instance of capture hooks that we still support as an exception.
// See NOTE [Deprecated capture hooks] for more context.
capture.DO_NOT_USE_DEPRECATED_register_capture_hook(
std::make_unique<DistAccumulateGradCaptureHook>(
std::dynamic_pointer_cast<AccumulateGrad>(
accumulateGradFn->shared_from_this()),
autogradContext));
}
}
}
// Mark all 'RecvRpcBackward' nodes as needing execution.
for (const auto& recvBackwardEdge : recvBackwardEdges) {
graphTask->exec_info_[recvBackwardEdge.function.get()].needed_ = true;
}
}
// Set graph task owner in a single thread since concurrent access to
// 'owner_' field is not permitted.
graphTask->owner_ = torch::autograd::CPU_DEVICE;
// Let autograd context take ownership of the GraphTask.
autogradContext->setGraphTask(std::move(graphTask));
}
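// Drains a fresh per-call CPU ready queue for the given NodeTask's GraphTask:
// evaluates functions until the queue is empty and, if the GraphTask has then
// completed, runs its post processing (which marks its future as completed).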
void DistEngine::execute_graph_task_until_ready_queue_empty(
NodeTask&& node_task,
bool incrementOutstandingTasks) {
engine_.initialize_device_threads_pool();
// Create a ready queue per call to traverse the graph_task from
// root_to_execute. This allows concurrent execution of the same GraphTask from
// different threads.
std::shared_ptr<ReadyQueue> cpu_ready_queue = std::make_shared<ReadyQueue>();
auto graph_task = node_task.base_.lock();
if (graph_task == nullptr) {
LOG(ERROR) << "GraphTask has expired for NodeTask: "
<< node_task.fn_->name() << ", skipping execution.";
return;
}
cpu_ready_queue->push(std::move(node_task), incrementOutstandingTasks);
torch::autograd::set_device(torch::autograd::CPU_DEVICE);
while (!cpu_ready_queue->empty()) {
std::shared_ptr<GraphTask> local_graph_task;
{
// Scope this block of execution since NodeTask is not needed after this
// block and can be deallocated (releasing any references to grad tensors
// held as part of inputs_).
NodeTask task = cpu_ready_queue->pop();
local_graph_task = task.base_.lock();
if (!local_graph_task) {
continue;
}
if (task.fn_ && !local_graph_task->has_error_.load()) {
at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
try {
GraphTaskGuard guard(local_graph_task);
engine_.evaluate_function(
local_graph_task, task.fn_.get(), task.inputs_, cpu_ready_queue);
} catch (std::exception& e) {
engine_.thread_on_exception(local_graph_task, task.fn_, e);
// Break the loop on error so that we immediately stop the execution
// of this GraphTask, mark it as completed if necessary and return the
// future with the proper error message.
break;
}
}
}
// Decrement the outstanding task count.
--local_graph_task->outstanding_tasks_;
}
// Check if we've completed execution.
if (graph_task->completed()) {
// We don't need to explicitly notify the owner thread, since
// 'mark_as_completed_and_run_post_processing' would mark the Future as
// completed and this would notify the owner thread that the task has been
// completed.
graph_task->mark_as_completed_and_run_post_processing();
}
}
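// Launches the local autograd engine for the GraphTask stored on the context
// and returns a future that completes once all gradient accumulation into the
// context has finished (or carries the error if the pass failed).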
c10::intrusive_ptr<c10::ivalue::Future> DistEngine::
runEngineAndAccumulateGradients(
const ContextPtr& autogradContext,
const std::shared_ptr<Node>& graphRoot,
const edge_list& outputEdges,
bool incrementOutstandingTasks) {
// Cleanup previous state for outstanding RPCs. Outstanding RPCs could be
// lingering if we're running backward multiple times and some of the
// passes ran into errors.
autogradContext->clearOutstandingRpcs();
auto graphTask = autogradContext->retrieveGraphTask();
at::launch([this, graphTask, graphRoot, incrementOutstandingTasks]() {
execute_graph_task_until_ready_queue_empty(
/*node_task*/ NodeTask(graphTask, graphRoot, InputBuffer(0)),
/*incrementOutstandingTasks*/ incrementOutstandingTasks);
});
// Use a reference here to avoid refcount bump on futureGrads.
auto& futureGrads = graphTask->future_result_;
// Build a future that waits for the callbacks to execute (since callbacks
// execute after the original future is completed). This ensures we return a
// future that waits for all gradient accumulation to finish.
auto accumulateGradFuture =
c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture](
c10::ivalue::Future& futureGrads) {
if (futureGrads.hasError()) {
// Don't accumulate gradients if we receive an error.
// We must add the node information here since DistEngine::execute
// waits on accumulateGradFuture and will throw an exception once we
// set the error below.
std::string errorMsg = c10::str(
"Error on Node ",
DistAutogradContainer::getInstance().getWorkerId(),
": ",
futureGrads.tryRetrieveErrorMessage());
accumulateGradFuture->setError(std::make_exception_ptr(
c10::ivalue::Future::FutureError(std::move(errorMsg))));
return;
}
try {
const variable_list& grads = futureGrads.constValue().toTensorVector();
TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
accumulateGradFuture->markCompleted(c10::IValue());
} catch (std::exception& e) {
accumulateGradFuture->setErrorIfNeeded(std::current_exception());
}
});
return accumulateGradFuture;
}
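// Entry point for nodes that participate in (but did not initiate) the
// backward pass: invoked when gradients for 'sendFunction' arrive over RPC.
// Initializes the GraphTask for the context on the first call and enqueues
// the send function for execution by the local engine.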
c10::intrusive_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
const ContextPtr& autogradContext,
const std::shared_ptr<SendRpcBackward>& sendFunction,
bool retainGraph) {
// Typically the local autograd engine ensures stream synchronizations between
// nodes in the graph. However, for distributed autograd the sendFunction
// inputs might have been retrieved over the wire on a separate stream and the
// sendFunction itself runs on a different stream. As a result, we need to
// manually synchronize those two streams here.
const auto& send_backward_stream = sendFunction->stream();
if (send_backward_stream) {
for (const auto& grad : sendFunction->getGrads()) {
const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
const auto default_stream = guard.getStream(grad.device());
if (send_backward_stream != default_stream) {
auto event = c10::Event{c10::DeviceType::CUDA};
event.record(default_stream);
send_backward_stream->wait(event);
}
}
}
std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
if (initializedContextIds_.find(autogradContext->contextId()) ==
initializedContextIds_.end()) {
edge_list outputEdges;
// Pass in a dummy graphRoot since all send functions are the roots.
auto dummyRoot = std::make_shared<GraphRoot>(edge_list(), variable_list());
computeDependencies(
autogradContext, {}, {}, dummyRoot, outputEdges, retainGraph);
// Mark the autograd context id as initialized and unlock.
initializedContextIds_.insert(autogradContext->contextId());
lock.unlock();
// Enqueue the current send function.
auto graphTask = autogradContext->retrieveGraphTask();
// Run the autograd engine.
auto accumulateGradFuture = runEngineAndAccumulateGradients(
autogradContext,
sendFunction,
outputEdges,
/*incrementOutstandingTasks=*/false);
// Build the 'uber' future that waits for everything.
auto callbackFuture =
c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
accumulateGradFuture->addCallback(
[autogradContext,
callbackFuture](c10::ivalue::Future& accumulateGradFuture) {
try {
if (accumulateGradFuture.hasError()) {
// Perform cleanup at the end of the backward pass (before we mark
// the future as completed).
DistEngine::getInstance().cleanupBackwardPass(autogradContext);
// Skip any further processing on errors.
callbackFuture->setError(accumulateGradFuture.exception_ptr());
return;
}
// Wait for all RPCs after the autograd engine is done.
auto rpcFuture =
autogradContext->clearAndWaitForOutstandingRpcsAsync();
rpcFuture->addCallback([callbackFuture, autogradContext](
c10::ivalue::Future& rpcFuture) {
try {
// Perform cleanup at the end of the backward pass (before
// we mark the future as completed).
DistEngine::getInstance().cleanupBackwardPass(autogradContext);
} catch (std::exception& e) {
callbackFuture->setErrorIfNeeded(std::current_exception());
return;
}
// Finally mark the 'uber' future as completed.
if (!rpcFuture.hasError()) {
callbackFuture->markCompleted(c10::IValue());
} else {
callbackFuture->setError(rpcFuture.exception_ptr());
}
});
} catch (std::exception& e) {
callbackFuture->setErrorIfNeeded(std::current_exception());
}
});
// Return the future which waits for all async processing to be done.
return callbackFuture;
} else {
lock.unlock();
auto graphTask = autogradContext->retrieveGraphTask();
at::launch([this, graphTask, sendFunction]() {
execute_graph_task_until_ready_queue_empty(
/*node_task*/ NodeTask(graphTask, sendFunction, InputBuffer(0)),
/*incrementOutstandingTasks*/ false);
});
auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
fut->markCompleted(c10::IValue());
return fut;
}
}
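// Entry point for the node that kicks off the distributed backward pass.
// On the Python side this is reached via torch.distributed.autograd, roughly
// (illustrative sketch):
//
//   with dist_autograd.context() as context_id:
//       loss = model(inputs).sum()
//       dist_autograd.backward(context_id, [loss], retain_graph=False)
//
// Validates the roots, computes dependencies, runs the engine and blocks
// until all local work and outstanding RPCs have completed.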
void DistEngine::execute(
int64_t contextId,
const variable_list& roots,
bool retainGraph) {
// Retrieve the context for the given context_id. This will throw if the
// context_id is invalid.
auto autogradContext =
DistAutogradContainer::getInstance().retrieveContext(contextId);
// Perform initial pre-processing.
edge_list rootEdges;
variable_list grads;
validateRootsAndRetrieveEdges(roots, rootEdges, grads);
std::shared_ptr<Node> graphRoot =
std::make_shared<GraphRoot>(rootEdges, grads);
edge_list outputEdges;
// Compute dependencies locally, starting from all roots and all 'send'
// functions.
{
std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
// Context should not have been initialized already.
TORCH_INTERNAL_ASSERT(
initializedContextIds_.find(autogradContext->contextId()) ==
initializedContextIds_.end());
computeDependencies(
autogradContext, rootEdges, grads, graphRoot, outputEdges, retainGraph);
// Mark the autograd context id as initialized.
initializedContextIds_.insert(autogradContext->contextId());
}
BackwardPassCleanupGuard guard(autogradContext);
// This needs to be blocking and as a result we wait for the future to
// complete.
runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
->waitAndThrow();
// Wait for all of the outstanding rpcs to complete.
autogradContext->clearAndWaitForOutstandingRpcsAsync()->waitAndThrow();
}
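// Releases the per-pass state held on the context once a backward pass is
// done: resets the GraphTask, clears outstanding RPCs and removes the context
// id from the set of initialized contexts.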
void DistEngine::cleanupBackwardPass(const ContextPtr& autogradContext) {
// Validate only the GraphTask is holding a reference to the Future
// which holds gradients for the backward pass. This ensures that
// after 'resetGraphTask' is called below, there are no remaining
// references left to the gradients for the backward pass.
//
// This ensures our 'use_count' checks in
// AccumulateGrad::accumulateGrad are correct and we're
// not leaking any references to the gradients anywhere else.
const auto& futureGrads =
autogradContext->retrieveGraphTask()->future_result_;
TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1);
// Reset the graph task once we're done with all processing.
autogradContext->resetGraphTask();
// Clear any outstanding rpcs.
autogradContext->clearOutstandingRpcs();
// Clear the context id once we're done with the autograd engine
// processing.
std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
initializedContextIds_.erase(autogradContext->contextId());
}
size_t DistEngine::numBackwardPasses() const {
std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
return initializedContextIds_.size();
}
std::unordered_map<std::string, int64_t> DistEngine::getDebugInfo() const {
std::unordered_map<std::string, int64_t> debugInfo;
debugInfo[kNumBackwardPasses] = static_cast<int64_t>(numBackwardPasses());
debugInfo[kNumAutogradContexts] = static_cast<int64_t>(
DistAutogradContainer::getInstance().numAutogradContexts());
return debugInfo;
}
} // namespace torch::distributed::autograd