#pragma once

#include <mutex>
#include <unordered_set>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/distributed/autograd/context/context.h>

namespace torch::distributed::autograd {

// Forward declaration.
class BackwardPassCleanupGuard;

// This is a singleton class responsible for running distributed backward
// passes. This engine relies heavily on the vanilla autograd engine and tries
// to reuse it as much as possible. This class is mostly responsible for the
// distributed aspects of autograd and tries to hook into the autograd engine
// where convenient.

// Unlike the vanilla autograd engine, the distributed autograd engine
// accumulates the gradients in the appropriate DistAutogradContext. This avoids
// multiple trainer nodes stomping on each other's gradients.
class TORCH_API DistEngine {
 public:
  // Retrieve the singleton instance.
  static DistEngine& getInstance();

  // Given a list of root variables, start the distributed backward pass from
  // these variables and accumulate all the gradients in the current autograd
  // context on each node. This method is used to kick off distributed autograd
  // on a single node.
  void execute(
      int64_t context_id,
      const torch::autograd::variable_list& roots,
      bool retainGraph);

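  // Example (illustrative sketch): a caller on the node that owns the loss
  // might kick off the backward pass roughly as below, assuming `context_id`
  // and `loss` come from an active DistAutogradContext and a forward pass run
  // under it.
  //
  //   DistEngine::getInstance().execute(
  //       context_id, {loss}, /*retainGraph=*/false);
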
  // Given a send function to execute in the autograd engine, ensures we compute
  // dependencies only once for this node and enqueues the send function for
  // execution in the engine.
  // This method is used to kick off the autograd computation on a node when it
  // receives gradients from the corresponding 'recv' method on another node.
  // The gradients are accumulated in the provided autograd context.
  c10::intrusive_ptr<c10::ivalue::Future> executeSendFunctionAsync(
      const ContextPtr& autogradContext,
      const std::shared_ptr<SendRpcBackward>& sendFunction,
      bool retainGraph);

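  // Example (illustrative sketch): the RPC layer that receives gradients for a
  // 'send' node could hand them off and react to completion via the returned
  // Future, assuming `ctx` and `sendFn` are the matching DistAutogradContext
  // and SendRpcBackward.
  //
  //   auto future = DistEngine::getInstance().executeSendFunctionAsync(
  //       ctx, sendFn, /*retainGraph=*/false);
  //   future->addCallback([](c10::ivalue::Future& fut) {
  //     // Report errors or acknowledge completion over RPC here.
  //   });
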
  // Number of backward passes currently running for the Distributed Engine.
  size_t numBackwardPasses() const;

  // Returns key-value pairs consisting of useful debugging information related
  // to distributed autograd.
  std::unordered_map<std::string, int64_t> getDebugInfo() const;

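  // Example (illustrative sketch): the debug map can be dumped with ordinary
  // iteration; keys are counter names and values are their current readings.
  //
  //   for (const auto& kv : DistEngine::getInstance().getDebugInfo()) {
  //     LOG(INFO) << kv.first << " = " << kv.second;
  //   }
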
  DistEngine(const DistEngine&) = delete;
  DistEngine& operator=(const DistEngine&) = delete;
  DistEngine(DistEngine&&) = delete;
  DistEngine& operator=(DistEngine&&) = delete;

 private:
  // Make sure this is a singleton.
  DistEngine();
  ~DistEngine();

  // Validates the input roots for the backward computation and retrieves the
  // appropriate root edges and corresponding gradients. Populates 'rootEdges'
  // with the appropriate gradient edges and 'grads' with the gradients for
  // each edge.
  void validateRootsAndRetrieveEdges(
      const torch::autograd::variable_list& roots,
      torch::autograd::edge_list& rootEdges,
      torch::autograd::variable_list& grads);

  // Given the autograd context, root edges and grads, we compute dependencies
  // for the local node and fill out the provided GraphTask and GraphRoot with
  // appropriate information for the local autograd engine.
  // We also determine all leaf nodes (functions) in the graph and accumulate
  // them in outputEdges.
  void computeDependencies(
      const ContextPtr& context,
      const torch::autograd::edge_list& rootEdges,
      const torch::autograd::variable_list& grads,
      const std::shared_ptr<torch::autograd::Node>& graphRoot,
      torch::autograd::edge_list& outputEdges,
      bool retainGraph);

  // Given a pre-populated GraphTask and a root node, compute the backward pass
  // for the autograd graph until the graph task's ready queue is empty.
  //
  // This method assumes that the appropriate GraphTask has already been
  // initialized. It constructs a local ready queue to traverse the GraphTask
  // instead of using the GraphTask's embedded cpu_ready_queue, because the
  // distributed engine might run the same GraphTask from different
  // SendFunctions concurrently in different threads. The method only marks the
  // GraphTask as completed when it needs to, which means it might not mark it
  // as completed on every call, since the distributed engine needs to keep the
  // GraphTask alive until it has received all gradients.
  //
  // When `incrementOutstandingTasks=false`, the function does not increment
  // 'outstanding_tasks_' in the appropriate GraphTask. It is assumed we've
  // already done this beforehand for this task (to ensure we don't pre-mark
  // this graph_task as completed). This is useful in the distributed autograd
  // case where we need to increment 'outstanding_tasks_' first to indicate to
  // the local autograd engine that the graph task is not complete until it
  // receives the signals from other workers over the network.
  //
  // XXX: calling this function assumes that NO GPU NodeTasks will be executed
  // for the graph_task; the caller of this function needs to ensure this,
  // otherwise the behavior is undefined. A correct way to fix this is to
  // re-design the autograd engine so that the GPU worker thread behaves the
  // same as the CPU caller thread: record the operation/thread for the device
  // and reuse it in backward.
  // TODO: 1. Add an assert in the dist engine to ensure no GPU NodeTasks
  //          during backward.
  //       2. Properly set up the thread-local ready queue to enable reentrant
  //          backwards.
  void execute_graph_task_until_ready_queue_empty(
      torch::autograd::NodeTask&& node_task,
      bool incrementOutstandingTasks = true);

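  // Rough shape of the traversal described above (illustrative sketch only;
  // the actual implementation lives in dist_engine.cpp and differs in detail):
  // a fresh ReadyQueue local to this call is seeded with `node_task` and then
  // drained to completion.
  //
  //   auto localQueue = std::make_shared<torch::autograd::ReadyQueue>();
  //   localQueue->push(std::move(node_task), incrementOutstandingTasks);
  //   while (!localQueue->empty()) {
  //     // Pop a NodeTask, evaluate its function with the local engine, and
  //     // enqueue successors whose dependencies are now satisfied.
  //   }
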
  // Run the local autograd engine using the provided graphTask and graphRoot
  // and accumulate the gradients for the edges in 'outputEdges' into the
  // provided autograd context.
  c10::intrusive_ptr<c10::ivalue::Future> runEngineAndAccumulateGradients(
      const ContextPtr& autogradContext,
      const std::shared_ptr<torch::autograd::Node>& graphRoot,
      const torch::autograd::edge_list& outputEdges,
      bool incrementOutStandingTasks = true);

  // Run after the backward pass is done to clean up the appropriate data
  // structures.
  void cleanupBackwardPass(const ContextPtr& autogradContext);

  // Global thread to execute CPU continuations.
  void globalCpuThread(
      const std::shared_ptr<torch::autograd::ReadyQueue>& ready_queue);

  // Set of autograd context_ids which we have already initialized for
  // distributed autograd on this node (e.g., we have already computed
  // dependencies).
  std::unordered_set<int64_t> initializedContextIds_;

  mutable std::mutex initializedContextIdsLock_;

  // Reference to the local autograd engine.
  torch::autograd::Engine& engine_;

  // Ready queue used by the CPU thread in the distributed engine.
  // See Note [GPU to CPU continuations]
  std::shared_ptr<torch::autograd::ReadyQueue> global_cpu_ready_queue_;

  // See Note [GPU to CPU continuations]
  std::thread global_cpu_thread_;

  friend class BackwardPassCleanupGuard;
};

// Guard to clean up resources once the backward pass is done.
class BackwardPassCleanupGuard {
 public:
  explicit BackwardPassCleanupGuard(ContextPtr autogradContext)
      : autogradContext_(std::move(autogradContext)) {}

  ~BackwardPassCleanupGuard() {
    DistEngine::getInstance().cleanupBackwardPass(autogradContext_);
  }

 private:
  ContextPtr autogradContext_;
};

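// Example (illustrative sketch): the guard is an RAII helper, so scoping it
// around the backward work guarantees that cleanupBackwardPass() runs even if
// that work throws, assuming `ctx` is the ContextPtr for the current backward
// pass.
//
//   {
//     BackwardPassCleanupGuard guard(ctx);
//     // ... run the distributed backward pass for `ctx` ...
//   }  // cleanupBackwardPass(ctx) runs here.
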
} // namespace torch::distributed::autograd