Skip some nodes during discovery using sequence number (#52180)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/12635

This change will help us speed up autograd's discovery algorithm in cases where we use `.grad` and we try to "unroll" the training loop. For example the example in the issue and also https://github.com/pytorch/pytorch/pull/52180#issuecomment-783400832 observe an unbounded multiple of speed-up.

We do this by adding a new sequence_nr-type numbering: for each node, we maintain the length of the longest path from it to any leaf node. How does this help us speed up discovery (dfs)? Previously the bottleneck was that the dfs that computes which nodes need to be executed always explored every node. With this change, before we run dfs, we first compute the mininum seq_nr among all the nodes passed as the `inputs`. If let this be some number N, intuitively this means that dfs should stay at least N units away from any leaf node. So, if we find ourselves too close to any leaf node, we should stop our search early.

Edit:
After some discussion offline, the plan is:
 - make old sequence_nr a construct of the profiler. This means we can avoid accessing thread local state in cases where the profiler is disabled. Note that we cannot replace sequence_nr as-is because profiler's use-case requires that thread-id + sequence_nr can uniquely identify a given node in order for downstream users/programs to correlate nodes from backward and forward passes. This means we must maintain two sequence_nr's and that we have an extra field in Node.
 - In a future PR, we can potentially remove sequence_nr entirely from the profiler as well, but we avoid doing it now because we haven't measured, and its a larger effort because we'd have to mess around with the dispatcher and profiler

Testing with this [code](https://gist.github.com/kyunghyuncho/5fb9991ce1233f909051854a84b7148e), we see that runtime no longer increases as we iterate.

Before:
```
100: Time taken: 0.47s, loss: 1.1e+06
200: Time taken: 0.064s, loss: 6.5e+05
300: Time taken: 0.088s, loss: 4.4e+05
400: Time taken: 0.1s, loss: 3.2e+05
500: Time taken: 0.12s, loss: 2.5e+05
600: Time taken: 0.15s, loss: 2e+05
700: Time taken: 0.18s, loss: 1.7e+05
800: Time taken: 0.2s, loss: 1.4e+05
900: Time taken: 0.22s, loss: 1.2e+05
1000: Time taken: 0.24s, loss: 1.1e+05
1100: Time taken: 0.27s, loss: 9.3e+04
1200: Time taken: 0.3s, loss: 8.3e+04
1300: Time taken: 0.34s, loss: 7.4e+04
1400: Time taken: 0.36s, loss: 6.7e+04
1500: Time taken: 0.38s, loss: 6.1e+04
1600: Time taken: 0.4s, loss: 5.6e+04
1700: Time taken: 0.42s, loss: 5.1e+04
1800: Time taken: 0.44s, loss: 4.7e+04
1900: Time taken: 0.47s, loss: 4.4e+04
2000: Time taken: 0.5s, loss: 4.1e+04
```
After:
```
100: Time taken: 0.49s, loss: 1.2e+06
200: Time taken: 0.031s, loss: 6.9e+05
300: Time taken: 0.031s, loss: 4.6e+05
400: Time taken: 0.031s, loss: 3.3e+05
500: Time taken: 0.031s, loss: 2.6e+05
600: Time taken: 0.031s, loss: 2.1e+05
700: Time taken: 0.031s, loss: 1.7e+05
800: Time taken: 0.031s, loss: 1.4e+05
900: Time taken: 0.031s, loss: 1.2e+05
1000: Time taken: 0.031s, loss: 1.1e+05
1100: Time taken: 0.031s, loss: 9.6e+04
1200: Time taken: 0.031s, loss: 8.6e+04
1300: Time taken: 0.031s, loss: 7.7e+04
1400: Time taken: 0.031s, loss: 7e+04
1500: Time taken: 0.031s, loss: 6.3e+04
1600: Time taken: 0.031s, loss: 5.8e+04
1700: Time taken: 0.031s, loss: 5.3e+04
1800: Time taken: 0.031s, loss: 4.9e+04
1900: Time taken: 0.031s, loss: 4.5e+04
2000: Time taken: 0.032s, loss: 4.2e+04

```
Testing w/ small graph to check for regression:
```
import torch
from torch.utils.benchmark import Timer

setup="""
a = torch.rand((2, 2), requires_grad=True)
b = torch.rand((2, 2), requires_grad=True)
gradient = torch.ones(2, 2)
"""

stmt="""
torch.autograd.grad(a*b, [a, b], gradient)
"""

timer = Timer(stmt, setup)

print(timer.timeit(10000))
print(timer.collect_callgrind(100))
```
Result: there doesn't seem to be any significant regression
```
Time before: 12.74 us
Time after: 13.12 us
Instruction count before:
                           All          Noisy symbols removed
    Instructions:      8078960                    8000882
    Baseline:             4226                       3838
Instruction count after:
                           All          Noisy symbols removed
    Instructions:      8091846                    8017940
    Baseline:             4336                       3838
100 runs per measurement, 1 thread
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52180

Reviewed By: gchanan, zhangguanheng66

Differential Revision: D26794387

Pulled By: soulitzer

fbshipit-source-id: c00d387a29f151109c33dc6f1b56a8f275cdec58
This commit is contained in:
Jeffrey Wan
2021-03-04 16:08:56 -08:00
committed by Facebook GitHub Bot
parent 85109ce427
commit 4739d15a67
6 changed files with 129 additions and 24 deletions

View File

@ -838,9 +838,21 @@ void Engine::evaluate_function(
}
}
/* Computes the number of dependencies for each function which requires grad */
auto Engine::compute_dependencies(Node* root, GraphTask& task) -> void {
// Just to make sure that they will never be added to the queue again
inline static uint64_t compute_min_topological_nr(const edge_list& outputs) {
// Computes the mininum topological number among all the outputs
if (outputs.empty()) {
return 0;
}
auto min_topo_nr = std::numeric_limits<uint64_t>::max();
for (auto & output_edge : outputs) {
auto topo_nr = output_edge.function.get()->topological_nr();
min_topo_nr = (min_topo_nr < topo_nr) ? min_topo_nr : topo_nr;
}
return min_topo_nr;
}
auto Engine::compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr) -> void {
// Computes the number of dependencies for each function which requires grad
std::unordered_set<Node*> seen;
std::vector<Node*> queue { root };
@ -849,6 +861,9 @@ auto Engine::compute_dependencies(Node* root, GraphTask& task) -> void {
auto& dependencies = task.dependencies_;
while (!queue.empty()) {
auto fn = queue.back(); queue.pop_back();
if (fn->topological_nr() < min_topo_nr) {
continue;
}
for (const auto& edge : fn->next_edges()) {
if (auto next_ptr = edge.function.get()) {
dependencies[next_ptr] += 1;
@ -870,7 +885,7 @@ auto Engine::execute(const edge_list& roots,
return msg;
});
// A frech first time Engine::execute call should start on the CPU device, initialize
// A fresh first time Engine::execute call should start on the CPU device, initialize
// a new thread local ready queue on CPU or reuse the existing one (if there is one
// allocated already, i.e. consecutive backward calls, re-entrant backward calls),
// then memoize the local_ready_queue in GraphTask
@ -889,13 +904,15 @@ auto Engine::execute(const edge_list& roots,
roots.at(0).function :
std::make_shared<GraphRoot>(roots, inputs);
// Now compute the dependencies for all executable functions and queue the root
compute_dependencies(graph_root.get(), *graph_task);
auto min_topo_nr = compute_min_topological_nr(outputs);
// Now compute the dependencies for all executable functions
compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);
if (!outputs.empty()) {
graph_task->init_to_execute(*graph_root, outputs, accumulate_grad);
graph_task->init_to_execute(*graph_root, outputs, accumulate_grad, min_topo_nr);
}
// Queue the root
if (skip_dummy_node) {
InputBuffer input_buffer(roots.at(0).function->num_inputs());
auto input = inputs.at(0);
@ -1130,7 +1147,7 @@ void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
thread_pool_shared_->work_.notify_one();
}
void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad) {
void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad, uint64_t min_topo_nr) {
// Populates exec_info so nodes that should be executed have `exec_info[node].needed_ = true`
// Only nodes that have a path to any edge in `outputs` should be executed.
// The code below populates exec_info using recursion, but the actual code does this
@ -1217,6 +1234,11 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool
if (child_fn) {
// (2) next child exists but has not been seen
if (child_fn->topological_nr() < min_topo_nr) {
// child created before the first output means this child cannot have
// an edge to output
continue;
}
stack.emplace_back(child_fn);
} else {
// (3) no next child exists for `fn` means its `needed` has already been

View File

@ -112,7 +112,7 @@ struct GraphTask: std::enable_shared_from_this<GraphTask> {
std::unordered_set<c10::Stream> leaf_streams;
void init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad);
void init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad, uint64_t min_topo_nr);
// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
@ -332,7 +332,7 @@ struct TORCH_API Engine {
protected:
Engine();
void compute_dependencies(Node* root, GraphTask& task);
void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr);
// initialize the thread local ready queue with the ready queue that is created
// elsewhere (i.e. thread_init, Engine::execute, etc), or create a new

View File

@ -94,17 +94,21 @@ class NodeGuard {
// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B`
// are created in one thread and `C` is created in a new thread, there are *no
// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`.
// See NOTE [ Sequence Number] for more details on the usages of sequence number.
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct TORCH_API Node : std::enable_shared_from_this<Node> {
public:
/// Construct a new `Node` with the given `next_edges`. `sequence_nr` is
/// a (currently THE) hint to prioritization in the backward() pass, with
/// higher sequence numbers prioritized before lower sequence numbers.
/// Construct a new `Node` with the given `next_edges`
explicit Node(
uint64_t sequence_nr,
edge_list&& next_edges = edge_list())
: sequence_nr_(sequence_nr),
next_edges_(std::move(next_edges)) {
for (const Edge& edge: next_edges_) {
update_topological_nr(edge);
}
if (AnomalyMode::is_enabled()) {
metadata()->store_stack();
@ -116,12 +120,15 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
}
if (profiler::profilerEnabled()) {
// If profiler is enabled, thread_id is stored.
// See NOTE [ Sequence Numbers ]
thread_id_ = at::RecordFunction::currentThreadId();
}
}
explicit Node(edge_list&& next_edges = edge_list())
: Node(at::sequence_number::get_and_increment(), std::move(next_edges)) {}
: Node(/*sequence_nr=*/at::sequence_number::get_and_increment(),
std::move(next_edges)) {}
/// Nodes are neither copyable nor moveable.
Node(const Node& other) = delete;
@ -140,7 +147,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
bool pre_sampled = false;
if (at::shouldRunRecordFunction(&pre_sampled)) {
// Using RecordFunction to trogger observers in the backward pass
// Using RecordFunction to trigger observers in the backward pass
at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION, pre_sampled);
if (guard.isActive()) {
// Using sequence number and thread id to correlate with
@ -226,20 +233,39 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// Outputs ("Next Edges")
const Edge& next_edge(size_t index) const noexcept {
return next_edges_[index];
void update_topological_nr(const Edge& edge) {
TORCH_INTERNAL_ASSERT(!has_parent_,
"Cannot update a node's topological_nr after it already has a parent."
" If we allow this, we can no longer guarantee that a parent's"
" topo_nr is always greater than those of all its children")
Node* node = edge.function.get();
if (node) {
auto topo_nr = node->topological_nr();
if (topological_nr_ <= topo_nr) {
topological_nr_ = topo_nr + 1;
}
}
}
void set_next_edge(size_t index, Edge edge) {
update_topological_nr(edge);
next_edges_[index] = std::move(edge);
}
void add_next_edge(Edge edge) {
update_topological_nr(edge);
next_edges_.push_back(std::move(edge));
}
void set_next_edges(edge_list&& next_edges) {
next_edges_ = std::move(next_edges);
for(const auto& next_edge : next_edges_) {
update_topological_nr(next_edge);
}
}
const Edge& next_edge(size_t index) const noexcept {
return next_edges_[index];
}
const edge_list& next_edges() const noexcept {
@ -257,11 +283,60 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// Miscellaneous Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// The sequence number of this `Node`.
/// NOTE [ Sequence Number]
///
/// The sequence_nr has two main usages in autograd:
///
/// 1) Helps determine the node's execution priority in the engine.
/// All else being equal, nodes with higher priority numbers are executed first.
/// Thus, nodes corresponding to ops executed later are the first to be executed in
/// the backward pass. One caveat is that we prioritize AccumulateGrad nodes by
/// explicitly setting its sequence_nr to be UINT64_MAX.
/// 2) The sequence number of this `Node` is paired with with thread_id it was created in
/// as a unique identifier by the profiler to annotate recorded events.
/// The purpose of this is to help users (and possibly programs) interpreting the profiler's
/// output to correlate backward nodes with its forward ops.
/// We need both sequence_nr and thread_id to identify a node because sequence_nr is
/// thread_local, i.e., starts counting up from zero in a new thread
uint64_t sequence_nr() const noexcept {
return sequence_nr_;
}
// NOTE [ Topological Number ]
//
// topological_nr is used to prune branches in the DAG during autograd discovery as
// maintaining topological_nr helps us check in O(1) if there does NOT exist
// a directed path between two nodes.
//
// The topological order number of this `Node` representing the length of the
// longest possible path from this Node to any leaf node. If you are leaf node,
// aka AccumulateGrad, this will be zero. This value has the property that
// For every pair of nodes X, Y in G, existence of a directed path from X to Y
// implies topo_nr(X) > topo_nr(Y). The converse is not true, however, so we
// cannot prove existence of a path from X to Y, only non-existence.
//
// One assumption we make when using topo_nr is that once a node
// has been used, i.e., has a parent node, its own topo_nr does not change
// we have added some checks with the `has_parent_` field to enforce this.
//
// What NOT to do:
//
// 1) 2 -> 1 -> 0 In this diagram we label nodes with their topo_nr.
// 2 -> 1 -> 0 We have two simple graphs that can each arise from
// `t.exp().exp()`, for example.
// 2) 2 -> 1 -> 0
// /
// 2 -> 1 -> 0 We add 2 as a next edge to 1 even though 1 already
// has a parent.
// 3) 2 -> 1 -> 0
// /
// 2 -> 3 -> 0 2 < 3, yet there exists a path from 2 to 3!
//
uint64_t topological_nr() const noexcept {
has_parent_ = true;
return topological_nr_;
}
// assigning a node as a parent to this node
void assign_parent();
@ -387,10 +462,18 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
/// Calls `apply()`, but instruments it with tracing machinery.
variable_list traced_apply(variable_list inputs);
// Since `Node`s are neither copyable nor moveable, we can have const
// fields.
// Sequence number used to correlate backward nodes with forward ops in the
// profiler and provide determinisim in the engine.
const uint64_t sequence_nr_;
// See NOTE [ Topological Number ]
uint64_t topological_nr_ = 0;
// Tracks whether this node has been added as the next_edge of another node
// via set_next_edge(s), which always calls topological_nr() of all its children
// See NOTE [ Topological Number ] for why we need this.
mutable bool has_parent_ = false;
// Id of the thread that created the instance
uint64_t thread_id_ = 0;

View File

@ -17,8 +17,8 @@ namespace torch { namespace autograd {
// AccumulateGrad sets sequence_nr to the max value so it's always called
// ASAP during backwards.
AccumulateGrad::AccumulateGrad(Variable variable_)
: Node(/*sequence_nr=*/UINT64_MAX)
, variable(std::move(variable_)) {
: Node(/*sequence_nr=*/UINT64_MAX),
variable(std::move(variable_)) {
add_input_metadata(variable);
}

View File

@ -68,7 +68,7 @@ struct TORCH_API UndefinedGradBackward : public Node {
struct TORCH_API GraphRoot : public Node {
GraphRoot(edge_list functions, variable_list inputs)
: Node(std::move(functions)),
outputs(std::move(inputs)) {
outputs(std::move(inputs)) {
// Ensures calls to stream() on a GraphRoot instance reflect current stream(s)
// on devices of root grad tensors at the time the instance is constructed.
for (const auto& t : outputs) {

View File

@ -279,7 +279,7 @@ void DistEngine::computeDependencies(
// Create a dummy GraphRoot and run init_to_execute with it.
GraphRoot dummyRoot(edges, {});
graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false);
graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0);
for (auto& mapEntry : graphTask->exec_info_) {
auto& execInfo = mapEntry.second;
if (!execInfo.captures_) {