Autograd graphtask trim unnecessary edges (#82544)

### Introduction
<!-- What did you change and why was it needed? -->

Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current Autograd engine.

For more detail: The backward function of a `matmul` operator (e.g., `linear` `addmm` `mm`), has two matmuls, one for `input gradient` and another for `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t the nn input `x`,  only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine will always calculate the `weight gradient` if `weight` requires gradient (the calculation of the high-order derivative is performed during training).

The figure attached shows the autograd graph of the following code snippet:
```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# first order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The path with  is not needed when calculating derivatives.

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue
<!-- Link to Issue ticket or RFP -->
Related issue: https://github.com/pytorch/pytorch/issues/56500

### Method
When calling `torch.autograd.grad`, `exec_info_` is created for each GraphTask, which allows filtering paths on the graph that are not needed. However, when the GraphTask calls into the node, the node still does not know whether the edges are needed or not. In the case of matmul, `weight.requires_grad is True` so the weight gradient is always calculated.

Following https://github.com/pytorch/pytorch/issues/56500#issuecomment-825694656, this PR passes the graph task's thread_local `exec_info_` into the node, so it could trim unnecessary edges during `torch.autograd.grad` calls.

### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark result:
6 hidden layers, batch size 10000, on A100

FP32 result
| hessian benchmark             | FP32 (before) | FP32 (After)      | FP32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.90X) | 29.547 ms (1.90X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.47X) | 68.988 ms (1.18X)       |

TF32 result
| hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

For FP32 result, we could get 1.9X speed up for hessian calculation, and 1.47X speed up during training, which is even faster than functorch `vmap(jacfwd(jacrev` implementation. (functorch has performance regression on v0.2.0, https://github.com/pytorch/functorch/issues/989, so we are using v0.1.1 for benchmark)

@zou3519 does functorch also includes similar optimizations during hessian calculation? If not, what do we need to do so the functorch could also benefit from this PR?

### Testing
<!-- How did you test your change? -->

- [x] we need to figure out a way for unittest

### Thanks
Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82544
Approved by: https://github.com/soulitzer
This commit is contained in:
richard
2022-08-11 18:50:09 +00:00
committed by PyTorch MergeBot
parent d438e86719
commit 382ef1fda7
15 changed files with 402 additions and 179 deletions

View File

@ -2,6 +2,7 @@
#include <functorch/csrc/CustomFunction.h>
#include <ATen/ATen.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/FunctionsManual.h>
@ -192,7 +193,7 @@ variable_list GenericPythonBackward::apply(variable_list&& grads) {
args.emplace_back(saved.unpack(shared_from_this()));
}
if (should_compute_output({ tensors_ix })) {
if (task_should_compute_output({ tensors_ix })) {
auto handle = backward_fn_->typed<custom_function_t>();
auto grad_result = handle.call(args);
grad_inputs = grad_result;

View File

@ -4,6 +4,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <torch/torch.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <test/cpp/api/support.h>
@ -276,6 +277,102 @@ TEST(CustomAutogradTest, CustomFunction) {
ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}
TEST(CustomAutogradTest, GraphTaskTrimEdges) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(
AutogradContext* ctx,
Variable var1,
Variable var2,
int mul,
bool needs_input1_grad,
bool needs_input2_grad) {
// setup the expected should and should not compute idx
ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
ctx->saved_data["needs_input2_grad"] = needs_input2_grad;
ctx->saved_data["mul"] = mul;
ctx->save_for_backward({var1, var2});
return var1 + mul * var2 + var1 * var2;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Test `needs_input_grad` method is working correctly.
// We have to test this within the backward function.
auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool();
auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool();
IndexRange var1_idx = {0, 1};
IndexRange var2_idx = {1, 2};
EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
EXPECT_EQ(
ctx->needs_input_grad({var1_idx, var2_idx}),
needs_input1_grad || needs_input2_grad);
// calculate gradients
int mul = ctx->saved_data["mul"].toInt();
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];
Variable grad_var1, grad_var2;
if (ctx->needs_input_grad(0)) {
grad_var1 = grad_output[0] + grad_output[0] * var2;
}
if (ctx->needs_input_grad(1)) {
grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
}
variable_list output = {
grad_var1,
grad_var2,
Variable(),
Variable(),
Variable(),
};
return output;
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
auto go = torch::ones_like(x);
Variable out;
// grad_x
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ true,
/* needs_input2_grad= */ false);
auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));
// grad_y
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ false,
/* needs_input2_grad= */ true);
auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);
// grad_x and grad_y
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ true,
/* needs_input2_grad= */ true);
auto grads = torch::autograd::grad({out}, {x, y}, {go});
ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
}
TEST(CustomAutogradTest, FunctionReturnsInput) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var1) {

View File

@ -87,7 +87,7 @@ GRAD_INPUT_MASK = CodeTemplate(
DERIVATIVE_SINGLE = CodeTemplate(
"""\
if (should_compute_output({ ${name}_ix })) {
if (task_should_compute_output({ ${name}_ix })) {
auto grad_result = ${derivative};
copy_range(grad_inputs, ${name}_ix, grad_result);
}
@ -96,7 +96,7 @@ if (should_compute_output({ ${name}_ix })) {
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
"""\
if (should_compute_output({ ${name}_ix })) {
if (task_should_compute_output({ ${name}_ix })) {
copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
}
"""
@ -104,7 +104,7 @@ DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
DERIVATIVE_MULTI = CodeTemplate(
"""\
if (should_compute_output({ ${idx_ranges} })) {
if (task_should_compute_output({ ${idx_ranges} })) {
${grad_input_mask}
auto grad_result = ${derivative};
${copy_ranges}
@ -673,7 +673,9 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
)
else:
if "grad_input_mask" in formula:
masks = [f"should_compute_output({{ {n}_ix }})," for n in var_names]
masks = [
f"task_should_compute_output({{ {n}_ix }})," for n in var_names
]
grad_input_mask = GRAD_INPUT_MASK.substitute(
masks=masks, n=len(var_names)
)

View File

@ -1,6 +1,12 @@
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/variable.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#endif
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>

View File

@ -509,6 +509,19 @@ variable_list AutogradContext::get_saved_variables() const {
return saved;
}
bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
auto ptr = grad_fn_.lock();
TORCH_INTERNAL_ASSERT(ptr);
return ptr->task_should_compute_output(output_edge_index);
}
bool AutogradContext::needs_input_grad(
std::initializer_list<IndexRange> idxs) const {
auto ptr = grad_fn_.lock();
TORCH_INTERNAL_ASSERT(ptr);
return ptr->task_should_compute_output(idxs);
}
void AutogradContext::mark_dirty(const variable_list& inputs) {
dirty_inputs_.clear();
dirty_inputs_.reserve(inputs.size());

View File

@ -131,6 +131,11 @@ struct TORCH_API AutogradContext {
const std::unordered_set<at::TensorImpl*>& get_and_bump_dirty() const;
const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const;
/// Expose the Node's `task_should_compute_output` method to the cpp
/// custom autograd Function as `needs_input_grad`.
bool needs_input_grad(size_t output_edge_index) const;
bool needs_input_grad(std::initializer_list<IndexRange> idxs) const;
private:
std::unordered_set<at::TensorImpl*> non_differentiable_;
std::unordered_set<at::TensorImpl*> dirty_inputs_;

View File

@ -13,6 +13,12 @@
#include <ATen/Parallel.h>
#include <ATen/detail/CUDAHooksInterface.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/isnan.h>
#endif
#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/Stream.h>
@ -368,6 +374,17 @@ void GraphTaskGuard::restore_current_graph_task() {
current_graph_task = std::move(last_graph_task_);
}
// The current graph task's exec_info is being used to trim unnecessary edegs
// during node evaluation, see `Node.task_should_compute_output()` function.
const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info() {
return current_graph_task ? &current_graph_task->exec_info_ : nullptr;
}
void add_node_to_current_graph_task_exec_info(Node* fn) {
current_graph_task->exec_info_[fn].needed_ = true;
}
// NOTE: graph_tasks do not necessarily form a stack. Imagine this
// case:
//

View File

@ -10,6 +10,7 @@
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/warnings.h>
@ -35,9 +36,6 @@ struct ReadyQueue;
namespace torch {
namespace autograd {
static constexpr int NO_DEVICE = -2;
static constexpr int CPU_DEVICE = -1;
// Maximum reentrant backward depth before switching to a new thread
// This limit is based on the TSAN's deadlock detector, where it will
// fail if a program hold more than 65 locks in one thread at once.
@ -52,172 +50,6 @@ void validate_outputs(
variable_list& grads,
const std::function<std::string(const std::string&)>& format_error);
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask : std::enable_shared_from_this<GraphTask> {
std::atomic<uint64_t> outstanding_tasks_{0};
// Indicates if an error occurred while executing any task. When this is
// true, it signals all threads to stop executing.
std::atomic_bool has_error_{false};
std::atomic_bool future_completed_{false};
// It is safe to read keep_graph_ without synchronization
bool keep_graph_;
// To protect reads/writes to not_ready_, dependencies_, captured_vars_,
// has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
std::mutex mutex_;
std::unordered_map<Node*, InputBuffer> not_ready_;
std::unordered_map<Node*, int> dependencies_;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ExecInfo {
struct Capture {
Capture(const Capture&) = delete;
Capture(Capture&&) = default;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Capture(int input_idx, int output_idx)
: input_idx_(input_idx), output_idx_(output_idx) {}
int input_idx_; // within Node inputs
int output_idx_; // within the output vector of a GraphTask
// This hook will be executed after a grad is captured. The captured
// grad will be replaced by the return value of the hook.
struct GradCaptureHook {
virtual ~GradCaptureHook() = default;
virtual at::Tensor operator()(const at::Tensor& grad) = 0;
};
// The hooks will be called one by one in the order as they were added.
// The input grad of a hook will be the output of its preceding hook. The
// first hook will take the captured grad as the input. The output of the
// last hook will replace the captured grad.
std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
};
bool should_execute() const {
return needed_ || captures_;
}
bool needed_ = false;
std::unique_ptr<std::vector<Capture>> captures_;
};
// Exec info has a bit complicated semantics. If it's empty, it means the task
// is run in a "default" mode, which means that all next_edges we encounter
// should get executed. If it's not empty, only functions that have an entry
// and this entry has needed == True should be executed. exec_info is only
// empty when the graph is executed via .backward() and the inputs parameter
// is not passed. Otherwise, when executed through .grad(), or when inputs arg
// is specified for .backward(), exec_info will be non-empty.
//
// exec_info_ is safe to read without synchronization
std::unordered_map<Node*, ExecInfo> exec_info_;
// Captures variables are grads captured that we return to the user. After
// execution of the GraphTask is completed, the captured_vars_ are moved
// out of the GraphTask and are no longer valid.
std::vector<Variable> captured_vars_;
// Note: this field is not ready to be used until the proper
// `thread_locals_.set_grad_mode()` call in the constructor.
at::ThreadLocalState thread_locals_ = at::ThreadLocalState();
std::unordered_set<c10::Stream> leaf_streams;
// Per-device current streams of the execute() that called this GraphTask.
// These will be synced with leaf_streams in exec_post_processing.
std::vector<c10::optional<c10::Stream>> caller_current_streams_;
// Collects caller_current_streams_
void stash_current_streams();
void init_to_execute(
Node& graph_root,
const edge_list& outputs,
bool accumulate_grad,
uint64_t min_topo_nr);
// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
// Safe to read owner_ and reentrant_depth_ without synchronizaton
int owner_;
// The number of parent graph tasks for this graph task
const int reentrant_depth_;
bool can_checkpoint() const {
return exec_info_.empty();
}
// check if the GraphTask is completed or not
bool completed();
// mark the graph task as completed and trigger post processing
void mark_as_completed_and_run_post_processing();
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function.
void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function. But doesn't signal completion on
// 'future_result_' right away. The user needs to explicitly mark
// 'future_result_' completed with an appropriate exception.
void set_exception_without_signal(const std::shared_ptr<Node>& fn);
// Whether or not to stop execution for this GraphTask when an error is
// encountered. When set to true, this would cause Engine::execute() to throw
// an exception as soon as the autograd engine receives an exception.
bool exit_on_error_;
// CPU threads are dedicated to processing CPU work for the backward they
// invoked. So any given graph task maintains its own cpu_ready_queue_ where
// you should send work for it to be done. We memoize the cpu_ready_queue_ per
// GraphTask so that we know which ready queue we should push to if we are on
// device thread (i.e. GPU) and but next NodeTask should be run on CPU.
std::shared_ptr<ReadyQueue> cpu_ready_queue_;
// Future representing the completion of the graph task. Notified when all
// tasks are done.
c10::intrusive_ptr<at::ivalue::Future> future_result_;
// Final callbacks installed during execution of this GraphTask
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_. Intentionally no reusing
// mutex_ as the two are protecting different data structures.
std::mutex final_callbacks_lock_;
utils::DelayWarningHandler warning_handler_;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GraphTask(
bool keep_graph,
bool grad_mode,
int reentrant_depth,
std::shared_ptr<ReadyQueue> cpu_ready_queue,
bool exit_on_error = false)
: keep_graph_(keep_graph),
owner_(NO_DEVICE),
reentrant_depth_(reentrant_depth),
exit_on_error_(exit_on_error),
cpu_ready_queue_(std::move(cpu_ready_queue)),
future_result_(c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()))) {
thread_locals_.set_grad_mode(grad_mode);
}
private:
// run GraphTask post processing
void exec_post_processing();
};
// The guard that sets and restores current_graph_task.
class GraphTaskGuard {
public:
explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
~GraphTaskGuard();
void restore_current_graph_task();
private:
std::shared_ptr<GraphTask> last_graph_task_;
};
struct NodeTask {
std::weak_ptr<GraphTask> base_;
std::shared_ptr<Node> fn_;

View File

@ -3,6 +3,7 @@
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>
@ -369,6 +370,18 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
/// Returns the name of the dynamic type of the function, for debugging.
virtual std::string name() const;
/// The difference between functions `should_compute_output` and
/// `task_should_compute_output`:
/// - `should_compute_output` should only be used during graph construction
/// and takes into account only requires_grad information
/// - `task_should_compute_output` should only be called during the backward
/// pass (unless called directly through grad_fn) and takes into account the
/// current graph task. Specifically, the autograd engine trims unnecessary
/// edges when `inputs` are specified, and during backward untrimmed nodes
/// left on the graph can/should check `task_should_compute_output` to see if
/// any outgoing edges have been trimmed by the engine. If that is the case,
/// gradient computation wrt those edges can be omitted.
///
/// Returns true if the particular output edge is active, and that particular
/// output of this function should be computed.
bool should_compute_output(size_t output_edge_index) const {
@ -387,6 +400,37 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
});
}
/// Same as the above `should_compute_output` function but will also
/// check whether this edge is needed within the current graph task.
bool task_should_compute_output(size_t output_edge_index) const {
TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range");
const auto& next = next_edges_[output_edge_index];
if (next.is_valid()) {
const auto exec_info = get_current_graph_task_exec_info();
if (exec_info && !exec_info->empty()) {
auto it = exec_info->find(next.function.get());
if (it == exec_info->end() || !it->second.should_execute()) {
return false; // this edge is not needed for the current graph_task
}
}
return true;
}
return false;
}
/// Returns true if any of the output edges in any of the ranges are active
/// and should be computed in the current graph task.
bool task_should_compute_output(
std::initializer_list<IndexRange> idxs) const {
return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
for (const auto i : c10::irange(range.first, range.second)) {
if (task_should_compute_output(i))
return true;
}
return false;
});
}
/// Returns the `PyObject` stored for this `Node` (for Python
/// interaction).
PyObject* pyobj() const noexcept {

View File

@ -3,6 +3,7 @@
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/variable.h>
#include <ATen/ATen.h>
@ -21,10 +22,10 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grads[0]);
variable_list grad_inputs(2);
if (grad->defined()) {
if (should_compute_output(0)) {
if (task_should_compute_output(0)) {
grad_inputs[0] = at::zeros_like(*grad, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
if (should_compute_output(1)) {
if (task_should_compute_output(1)) {
// Handle R->C copies without raising a warning
const auto src_type = src_options.dtype().toScalarType();
if (!c10::isComplexType(src_type) && grad->is_complex()) {
@ -85,6 +86,18 @@ auto CopySlices::apply(variable_list&& inputs) -> variable_list {
grad_slice = result.as_strided(view.sizes(), view.strides(), offset);
}
// Adding the missing nodes to the current graph's `exec_info`.
// This is a workaround because the current `GraphTask::init_to_execute`
// does not traverse into CopySlices node.
const auto exec_info = get_current_graph_task_exec_info();
if (exec_info && !exec_info->empty()) {
for (const auto& next : fn->next_edges()) {
if (next.is_valid()) {
add_node_to_current_graph_task_exec_info(next.function.get());
}
}
}
// TODO: We clone grad_slice because we modify it below and "fn" might save
// it for the backward of res. We might be able to avoid the clone() if
// double-backprop is disabled.
@ -92,7 +105,7 @@ auto CopySlices::apply(variable_list&& inputs) -> variable_list {
variable_list grad_inputs(num_outputs());
for (const auto i : c10::irange(res.size())) {
if (should_compute_output(i)) {
if (task_should_compute_output(i)) {
AT_ASSERT(res[i].defined());
if (i == 0) {
grad_slice.copy_(res[i]);

View File

@ -0,0 +1,192 @@
#pragma once
#include <ATen/ThreadLocalState.h>
#include <ATen/core/Tensor.h>
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/utils/warnings.h>
#include <vector>
namespace torch {
namespace autograd {
using edge_list = std::vector<Edge>;
struct ReadyQueue;
static constexpr int NO_DEVICE = -2;
static constexpr int CPU_DEVICE = -1;
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask : std::enable_shared_from_this<GraphTask> {
std::atomic<uint64_t> outstanding_tasks_{0};
// Indicates if an error occurred while executing any task. When this is
// true, it signals all threads to stop executing.
std::atomic_bool has_error_{false};
std::atomic_bool future_completed_{false};
// It is safe to read keep_graph_ without synchronization
bool keep_graph_;
// To protect reads/writes to not_ready_, dependencies_, captured_vars_,
// has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
std::mutex mutex_;
std::unordered_map<Node*, InputBuffer> not_ready_;
std::unordered_map<Node*, int> dependencies_;
// Note [Exec info]
// Exec info is created for each GraphTask, which allows filtering paths on
// the graph that are not needed. It has a bit complicated semantics. If it's
// empty, it means the task is run in a "default" mode, which means that all
// next_edges we encounter should get executed. If it's not empty, only
// functions that have an entry and this entry has needed == True should be
// executed. exec_info is only empty when the graph is executed via
// .backward() and the inputs parameter is not passed. Otherwise, when
// executed through .grad(), or when inputs arg is specified for .backward(),
// exec_info will be non-empty.
//
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ExecInfo {
struct Capture {
Capture(const Capture&) = delete;
Capture(Capture&&) = default;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Capture(int input_idx, int output_idx)
: input_idx_(input_idx), output_idx_(output_idx) {}
int input_idx_; // within Node inputs
int output_idx_; // within the output vector of a GraphTask
// This hook will be executed after a grad is captured. The captured
// grad will be replaced by the return value of the hook.
struct GradCaptureHook {
virtual ~GradCaptureHook() = default;
virtual at::Tensor operator()(const at::Tensor& grad) = 0;
};
// The hooks will be called one by one in the order as they were added.
// The input grad of a hook will be the output of its preceding hook. The
// first hook will take the captured grad as the input. The output of the
// last hook will replace the captured grad.
std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
};
bool should_execute() const {
return needed_ || captures_;
}
bool needed_ = false;
std::unique_ptr<std::vector<Capture>> captures_;
};
// exec_info_ is safe to read without synchronization
std::unordered_map<Node*, ExecInfo> exec_info_;
// Captures variables are grads captured that we return to the user. After
// execution of the GraphTask is completed, the captured_vars_ are moved
// out of the GraphTask and are no longer valid.
std::vector<Variable> captured_vars_;
// Note: this field is not ready to be used until the proper
// `thread_locals_.set_grad_mode()` call in the constructor.
at::ThreadLocalState thread_locals_ = at::ThreadLocalState();
std::unordered_set<c10::Stream> leaf_streams;
// Per-device current streams of the execute() that called this GraphTask.
// These will be synced with leaf_streams in exec_post_processing.
std::vector<c10::optional<c10::Stream>> caller_current_streams_;
// Collects caller_current_streams_
void stash_current_streams();
void init_to_execute(
Node& graph_root,
const edge_list& outputs,
bool accumulate_grad,
uint64_t min_topo_nr);
// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
// Safe to read owner_ and reentrant_depth_ without synchronizaton
int owner_;
// The number of parent graph tasks for this graph task
const int reentrant_depth_;
bool can_checkpoint() const {
return exec_info_.empty();
}
// check if the GraphTask is completed or not
bool completed();
// mark the graph task as completed and trigger post processing
void mark_as_completed_and_run_post_processing();
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function.
void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function. But doesn't signal completion on
// 'future_result_' right away. The user needs to explicitly mark
// 'future_result_' completed with an appropriate exception.
void set_exception_without_signal(const std::shared_ptr<Node>& fn);
// Whether or not to stop execution for this GraphTask when an error is
// encountered. When set to true, this would cause Engine::execute() to throw
// an exception as soon as the autograd engine receives an exception.
bool exit_on_error_;
// CPU threads are dedicated to processing CPU work for the backward they
// invoked. So any given graph task maintains its own cpu_ready_queue_ where
// you should send work for it to be done. We memoize the cpu_ready_queue_ per
// GraphTask so that we know which ready queue we should push to if we are on
// device thread (i.e. GPU) and but next NodeTask should be run on CPU.
std::shared_ptr<ReadyQueue> cpu_ready_queue_;
// Future representing the completion of the graph task. Notified when all
// tasks are done.
c10::intrusive_ptr<at::ivalue::Future> future_result_;
// Final callbacks installed during execution of this GraphTask
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_. Intentionally no reusing
// mutex_ as the two are protecting different data structures.
std::mutex final_callbacks_lock_;
utils::DelayWarningHandler warning_handler_;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GraphTask(
bool keep_graph,
bool grad_mode,
int reentrant_depth,
std::shared_ptr<ReadyQueue> cpu_ready_queue,
bool exit_on_error = false)
: keep_graph_(keep_graph),
owner_(NO_DEVICE),
reentrant_depth_(reentrant_depth),
exit_on_error_(exit_on_error),
cpu_ready_queue_(std::move(cpu_ready_queue)),
future_result_(c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()))) {
thread_locals_.set_grad_mode(grad_mode);
}
private:
// run GraphTask post processing
void exec_post_processing();
};
// The guard that sets and restores current_graph_task.
class GraphTaskGuard {
public:
explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
~GraphTaskGuard();
void restore_current_graph_task();
private:
std::shared_ptr<GraphTask> last_graph_task_;
};
TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info();
void add_node_to_current_graph_task_exec_info(Node* fn);
} // namespace autograd
} // namespace torch

View File

@ -3,6 +3,7 @@
#include <ATen/BatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>

View File

@ -5,7 +5,6 @@
// values in-place (adding an input twice will accumulate the result).
// This behaviour is needed and used only in backward graphs.
#include <ATen/ATen.h>
#include <memory>
#include <utility>
#include <vector>

View File

@ -1,6 +1,7 @@
#include <torch/csrc/jit/frontend/tracer.h>
#include <ATen/Backtrace.h>
#include <ATen/ScalarOps.h>
#include <ATen/TracerMode.h>
#include <ATen/core/Dict.h>
#include <ATen/core/functional.h>

View File

@ -375,7 +375,7 @@ struct DifferentiableGraphBackward : public autograd::Node {
private:
void produceOutput(size_t i, at::Tensor output, variable_list& outputs) {
if (should_compute_output(i)) {
if (task_should_compute_output(i)) {
const auto& edge = next_edge(i);
if (output.defined()) {
outputs.emplace_back(std::move(output));