mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This is the same as https://github.com/pytorch/pytorch/pull/164467 But it needs to be co-deved due to internal insanity. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164736 Approved by: https://github.com/soulitzer
627 lines
24 KiB
C++
627 lines
24 KiB
C++
#include <c10/util/irange.h>
|
|
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
|
#include <torch/csrc/autograd/autograd.h>
|
|
#include <torch/csrc/autograd/custom_function.h>
|
|
#include <torch/csrc/autograd/functions/accumulate_grad.h>
|
|
|
|
#include <utility>
|
|
|
|
namespace torch::autograd {
|
|
|
|
// This function has two main goals:
|
|
// 1) Use the user-provided jvp function to populate the outputs' forward
|
|
// gradient 2) Perform error checking to ensure that view and inplace ops are
|
|
// properly handled
|
|
//
|
|
// For 1) we have to:
|
|
// - Create a variable_list of grad_inputs based on the function inputs
|
|
// - Call the user jvp function with these to get the grad_outputs
|
|
// - Set the forward grad field on each output based on these grad_outputs
|
|
//
|
|
// For 2) we want to check the following:
|
|
// - If an output is a view, then the generated forward grad must be a view as
|
|
// well and
|
|
// the output's base's forward grad must be the output's forward grad's base.
|
|
// - If an input was modified inplace (it must be an output as well) we make
|
|
// sure that its
|
|
// forward grad was also modified inplace and already present on the
|
|
// corresponding output.
|
|
static void _process_forward_mode_AD(
|
|
const variable_list& inputs,
|
|
std::unordered_map<at::TensorImpl*, size_t> inputs_mapping,
|
|
const at::ArrayRef<std::optional<Variable>> raw_outputs,
|
|
const optional_variable_list& outputs,
|
|
const std::unordered_set<at::TensorImpl*>& non_differentiable,
|
|
const std::unordered_set<at::TensorImpl*>& dirty_inputs,
|
|
const _jvp_fn_t& jvp_user_function) {
|
|
// TODO handle multiple levels here
|
|
uint64_t level = 0;
|
|
|
|
const auto num_inputs = inputs.size();
|
|
const auto num_outputs = outputs.size();
|
|
|
|
// The tracking info below are used to perform the view and inplace checks.
|
|
// They are lazily initialized to reduce the cost of this function in the
|
|
// common case where the user is not using forward mode AD.
|
|
variable_list input_grads;
|
|
std::vector<int64_t> grad_versions;
|
|
std::vector<at::TensorImpl*> grad_impls;
|
|
std::unordered_map<at::TensorImpl*, size_t> inputs_bases;
|
|
|
|
auto init_tracked_info = [&]() {
|
|
input_grads.resize(num_inputs);
|
|
grad_versions.resize(num_inputs);
|
|
grad_impls.resize(num_inputs);
|
|
|
|
for (const auto i : c10::irange(num_inputs)) {
|
|
const auto& inp = inputs[i];
|
|
if (inp.is_view() && impl::get_view_autograd_meta(inp)->has_fw_view()) {
|
|
inputs_bases.emplace(
|
|
impl::get_view_autograd_meta(inp)
|
|
->get_forward_view()
|
|
.base_.unsafeGetTensorImpl(),
|
|
i);
|
|
} else {
|
|
inputs_bases.emplace(inp.unsafeGetTensorImpl(), i);
|
|
}
|
|
}
|
|
};
|
|
|
|
bool any_input_has_grad = false;
|
|
// Extract the input's forward gradients and record any info we will need
|
|
// later
|
|
for (const auto i : c10::irange(num_inputs)) {
|
|
const auto& inp = inputs[i];
|
|
if (!inp.defined()) {
|
|
continue;
|
|
}
|
|
const auto& fw_grad = inp._fw_grad(level);
|
|
if (fw_grad.defined()) {
|
|
if (!any_input_has_grad) {
|
|
any_input_has_grad = true;
|
|
init_tracked_info();
|
|
}
|
|
input_grads[i] = fw_grad;
|
|
grad_versions[i] = fw_grad._version();
|
|
grad_impls[i] = fw_grad.unsafeGetTensorImpl();
|
|
}
|
|
}
|
|
|
|
// If no input has forward grad, nothing to do here
|
|
if (!any_input_has_grad) {
|
|
return;
|
|
}
|
|
|
|
torch::autograd::variable_list forward_grads;
|
|
{
|
|
at::AutoFwGradMode fw_grad_mode(false);
|
|
forward_grads = jvp_user_function(inputs, std::move(input_grads));
|
|
}
|
|
|
|
const auto num_forward_grads = forward_grads.size();
|
|
// contrary to backward mode, we don't allow returning too many gradients
|
|
TORCH_CHECK(
|
|
num_forward_grads == num_outputs,
|
|
"Function's jvp returned "
|
|
"an invalid number of forward gradients (expected ",
|
|
num_outputs,
|
|
" but got ",
|
|
num_forward_grads,
|
|
")");
|
|
|
|
for (const auto i : c10::irange(num_outputs)) {
|
|
if (!raw_outputs[i].has_value()) {
|
|
continue;
|
|
}
|
|
const auto& out =
|
|
outputs[i].has_value() ? outputs[i].value() : at::Tensor();
|
|
auto out_tensor_impl = raw_outputs[i].value().unsafeGetTensorImpl();
|
|
bool is_differentiable =
|
|
(non_differentiable.count(out_tensor_impl) == 0 &&
|
|
isDifferentiableType(raw_outputs[i].value().scalar_type()));
|
|
const auto& out_grad = forward_grads[i];
|
|
if (!out.defined() || !is_differentiable) {
|
|
TORCH_CHECK(
|
|
!out_grad.defined(),
|
|
"Function's jvp returned a gradient at position ",
|
|
i,
|
|
", but "
|
|
" the corresponding forward output is not a differentiable Tensor."
|
|
"You should return None at that position instead.");
|
|
continue;
|
|
}
|
|
|
|
bool is_input = inputs_mapping.count(out_tensor_impl) > 0;
|
|
bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
|
|
|
|
if (is_modified) {
|
|
TORCH_CHECK(
|
|
is_input,
|
|
"Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
|
|
" is no need to pass it to mark_dirty().");
|
|
auto inp_idx = inputs_mapping[out_tensor_impl];
|
|
if (grad_impls[inp_idx]) {
|
|
// If there was already a forward grad for that input
|
|
// Just make sure that it is modified inplace and returned as-is
|
|
TORCH_CHECK(
|
|
out_grad._version() != grad_versions[inp_idx],
|
|
"An inplace custom Function is not modifying the "
|
|
"forward mode gradients inplace. If the forward is modifying an input inplace, then the jvp "
|
|
"function must modify the corresponding gradient inplace.")
|
|
TORCH_CHECK(
|
|
out_grad.unsafeGetTensorImpl() == grad_impls[inp_idx],
|
|
"An inplace custom Function is not returning the "
|
|
"forward mode gradients as-is. If the forward is modifying an input inplace, then the jvp "
|
|
"function must modify the gradient inplace and return it as-is.")
|
|
} else {
|
|
// If that Tensor didn't had gradients already, set the newly returned
|
|
// one We could also use inputs[inp_idx] here as it is the same as out
|
|
out._set_fw_grad(out_grad, level, /* is_inplace_op */ true);
|
|
}
|
|
} else {
|
|
// At this point, outputs[i] cannot be one of the input (raw_outputs[i]
|
|
// might be but was changed by the backward code)
|
|
TORCH_INTERNAL_ASSERT(
|
|
inputs_mapping.count(out.unsafeGetTensorImpl()) == 0);
|
|
|
|
if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) {
|
|
// If the output is a view
|
|
const auto& out_view_info =
|
|
impl::get_view_autograd_meta(out)->get_forward_view();
|
|
if (inputs_bases.count(out_view_info.base_.unsafeGetTensorImpl())) {
|
|
// And it is a view of an input (either that input is its base or they
|
|
// have a common base)
|
|
const auto matching_input_idx =
|
|
inputs_bases[out_view_info.base_.unsafeGetTensorImpl()];
|
|
const auto& matching_input = inputs[matching_input_idx];
|
|
|
|
const auto& matching_input_grad = matching_input._fw_grad(level);
|
|
|
|
// If the matching input has a forward grad, the user should have
|
|
// returned a view of that Tensor
|
|
if (matching_input_grad.defined()) {
|
|
TORCH_CHECK(
|
|
out_grad.is_view() &&
|
|
impl::get_view_autograd_meta(out_grad)->has_fw_view(),
|
|
"A custom Function's forward is returning a view (or an input as-is) but the jvp is not "
|
|
"returning a view.");
|
|
const auto& out_grad_base = impl::get_view_autograd_meta(out_grad)
|
|
->get_forward_view()
|
|
.base_;
|
|
if (matching_input_grad.is_view() &&
|
|
impl::get_view_autograd_meta(matching_input_grad)
|
|
->has_fw_view()) {
|
|
// If the matching input's grad is a view, ensure that the
|
|
// out_grad is a view of the same base
|
|
const auto& matching_input_grad_base =
|
|
impl::get_view_autograd_meta(matching_input_grad)
|
|
->get_forward_view()
|
|
.base_;
|
|
TORCH_CHECK(
|
|
matching_input_grad_base.unsafeGetTensorImpl() ==
|
|
out_grad_base.unsafeGetTensorImpl(),
|
|
"A custom Function is returning a view but the jvp is not returning a view of the same base as "
|
|
"the given grad input.");
|
|
} else {
|
|
// If the matching input's grad is not a view, then it must be the
|
|
// output gradient's base
|
|
TORCH_CHECK(
|
|
matching_input_grad.unsafeGetTensorImpl() ==
|
|
out_grad_base.unsafeGetTensorImpl(),
|
|
"A custom Function is returning a view but the jvp is not returning a view of the given grad input.");
|
|
}
|
|
} else {
|
|
// We have a view op where the input didn't have a forward grad but
|
|
// the user returned one for the output To ensure that we maintain
|
|
// the view/inplace constraints, we consider this as an inplace op
|
|
// This case CANNOT happen in codegen as all view ops are mapping
|
|
// from one Tensor to one Tensor and so the output of the view
|
|
// cannot have a forward grad if the base does not.
|
|
out._set_fw_grad(out_grad, level, /* is_inplace_op */ true);
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
out._set_fw_grad(out_grad, level, /* is_inplace_op */ false);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Returns `view_as_self_fn(self)` (expected to produce `self.view_as(self)`)
// computed with both grad mode and forward grad mode disabled.
//
// _process_backward_mode_ad calls this in two situations:
//  (1) An input was returned as an output without being modified. Returning a
//      view lets a new grad_fn be attached to the output Variable. no_grad is
//      used to mimic the behavior of the forward.
//  (2) An output is non-differentiable and does not require grad. The view is
//      not needed to attach a grad_fn there, but it keeps the custom forward
//      AD UX consistent: returning an input as-is is uniformly treated as if
//      `self.view_as(self)` had been returned for that output.
//
// Forward grad is disabled as well. Leaving it enabled while taking this view
// would mean the user-defined jvp may be silently ignored.
static at::Tensor _view_as_self_with_no_grad(
    const at::Tensor& self,
    const _view_as_self_fn_t& view_as_self_fn) {
  at::AutoFwGradMode disable_fw_grad(false);
  AutoGradMode disable_grad(false);
  // The view function is threaded through from the caller so that a Python
  // custom Function (rather than a C++ one) can call view_as from Python,
  // letting torch function logic still trigger.
  at::Tensor aliased_self = view_as_self_fn(self);
  return aliased_self;
}
|
|
|
|
// Attaches backward-mode autograd history to the outputs of a custom Function.
//
// For every raw output, input metadata is registered on `cdata` (when it is
// non-null) at the same position. Then:
//  - a non-differentiable output is left without a gradient edge: an input
//    returned as-is that does not require grad is re-viewed, an input that
//    requires grad is returned as a detached alias, and a non-view non-input
//    that requires grad is detached in place;
//  - an output marked dirty has its history rebased onto `cdata`;
//  - an input returned unmodified is replaced by a view of itself and gets
//    (cdata, output_nr) as its gradient edge;
//  - any other output gets (cdata, output_nr) as its gradient edge.
// View creation metadata is updated afterwards, and every dirty input must
// appear among the outputs.
//
// Returns the processed outputs; non-tensor outputs are std::nullopt.
static optional_variable_list _process_backward_mode_ad(
    const std::unordered_map<at::TensorImpl*, size_t>& inputs_mapping,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<std::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata,
    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
    const _view_as_self_fn_t& view_as_self_fn,
    bool pure_view) {
  const auto num_outputs = raw_outputs.size();

#ifndef STRIP_ERROR_MESSAGES
  const char* error_msg_input_returned_as_is =
      "A input that has been returned as-is as output is being saved for backward. "
      "This is not supported if you override setup_context. You should return and "
      "save a view of the input instead, e.g. with x.view_as(x) or setup ctx inside "
      "the forward function itself.";
#endif

  // Connects one output Variable to the graph at position `output_nr`.
  auto attach_history = [&](Variable& var,
                            uint32_t output_nr,
                            bool is_input,
                            bool is_modified,
                            bool is_differentiable,
                            bool is_saved_and_setup_context) {
    if (!is_differentiable) {
      if (var.requires_grad()) {
        // Return detached aliases of inputs, instead of changing their
        // requires_grad property.
        if (is_input) {
          var = var.detach();
        } else if (!var.is_view()) {
          var.detach_();
        }
        // A view of one of the custom Function's inputs is deliberately not
        // detached, mimicking a view returned from a no_grad block:
        //   x = torch.randn(3, requires_grad=True)
        //   with torch.no_grad():
        //       y = x.view(-1)
        // Here, `y` requires_grad (!).
        return;
      }
      if (is_input && !is_modified) {
        TORCH_CHECK(
            !is_saved_and_setup_context, error_msg_input_returned_as_is);
        var = _view_as_self_with_no_grad(var, view_as_self_fn);
      }
      return;
    }

    if (is_modified) {
      TORCH_CHECK(
          !(var.is_leaf() && var.requires_grad()),
          "a leaf Variable that requires grad has been used in an in-place operation.");
      // Tensors that are not inputs do not need to be marked as modified.
      if (!is_input) {
        const char* mark_dirty_error_msg =
            "ctx.mark_dirty() received a tensor that was not an input. "
            "Only input Tensors that have been mutated should be passed to "
            "ctx.mark_dirty().";
        // Reached in the view-of-an-intermediate case.
        TORCH_CHECK(!var.is_view(), mark_dirty_error_msg);
        TORCH_WARN(mark_dirty_error_msg);
      }
      // Rebasing a view rewrites the graph, which is only possible when this
      // Function has a single output.
      TORCH_CHECK(
          !(var.is_view() && num_outputs > 1),
          "If your Function modifies inplace an input that is a view"
          " of another Tensor, your Function cannot return more than one Tensor. This is not supported"
          " by the current autograd engine. You should either make sure the input is not a view (using"
          " .clone() for example) or make your Function only return one Tensor (potentially splitting"
          " it into two Functions: one doing the inplace that returns a single Tensor and a second one"
          " that does the other operations). You can ask on the forum https://discuss.pytorch.org/ if"
          " you need help to do this change.");

      // Transplant the grad_fn of the modified input:
      //   grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
      var.mutable_grad().reset();
      impl::clear_hooks(var);
      if (auto grad_acc_fn = impl::try_get_grad_accumulator(var)) {
        auto& accumulator = dynamic_cast<AccumulateGrad&>(*grad_acc_fn);
        accumulator.variable.reset();
      }
      // Repeats the leaf-mutation check performed at the top of this branch.
      check_inplace(var, true);
      impl::rebase_history(var, {cdata, output_nr});
      return;
    }

    if (is_input) {
      TORCH_CHECK(!is_saved_and_setup_context, error_msg_input_returned_as_is);
      var = _view_as_self_with_no_grad(var, view_as_self_fn);
    }
    impl::set_gradient_edge(var, {cdata, output_nr});
  };

  optional_variable_list outputs;
  std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
  outputs.reserve(num_outputs);
  int num_diff_outputs = 0;

  for (const auto i : c10::irange(num_outputs)) {
    const auto& maybe_output = raw_outputs[i];

    // Non-tensor outputs get an undefined_input placeholder in the graph
    // (non-differentiable tensors get one too, further below).
    if (!maybe_output.has_value()) {
      if (cdata) {
        const auto placeholder_nr =
            cdata->add_input_metadata(Node::undefined_input());
        AT_ASSERT(i == placeholder_nr);
      }
      outputs.emplace_back();
      continue;
    }

    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    Variable var = maybe_output.value();

    auto* const out_tensor_impl = var.unsafeGetTensorImpl();
    const bool is_input = inputs_mapping.count(out_tensor_impl) > 0;
    const bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
    const bool is_differentiable = cdata &&
        non_differentiable.count(out_tensor_impl) == 0 &&
        isDifferentiableType(var.scalar_type());
    const bool is_saved_and_setup_context =
        to_save_if_setup_context.count(out_tensor_impl) > 0;

    if (cdata) {
      const uint32_t output_nr = is_differentiable
          ? cdata->add_input_metadata(var)
          : cdata->add_input_metadata(Node::undefined_input());
      AT_ASSERT(i == output_nr);
    }

    attach_history(
        var,
        i,
        is_input,
        is_modified,
        is_differentiable,
        is_saved_and_setup_context);

    // For deprecation cycle. Can be removed after 1.6. When a view created in
    // no grad mode was detected during the forward, only warn the user (the
    // flag is not changed when an input that is a view is returned as-is). See
    // NOTE [ View + Inplace detection ] for why everything is replaced by a
    // warning.
    if ((!is_input || !is_modified) && var.is_view()) {
      // is_view() => diff_view_meta
      auto view_meta = impl::get_view_autograd_meta(var);
      view_meta->set_creation_meta(
          pure_view ? CreationMeta::DEFAULT : CreationMeta::IN_CUSTOM_FUNCTION);
    }

    num_diff_outputs += is_differentiable ? 1 : 0;

    outputs_impl.insert(out_tensor_impl);
    outputs.emplace_back(var);
  }

  // With more than one differentiable output, views must not be modified
  // inplace. See NOTE [ View + Inplace detection ] for more details.
  if (num_diff_outputs > 1) {
    for (auto& maybe_var : outputs) {
      if (!maybe_var.has_value()) {
        continue;
      }
      auto view_meta = impl::get_view_autograd_meta(maybe_var.value());
      if (view_meta && view_meta->has_bw_view()) {
        view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
      }
    }
  }

  // Every modified Tensor must be returned as-is for the rewrite to be valid.
  for (auto* const dirty_impl : dirty_inputs) {
    TORCH_CHECK(
        outputs_impl.count(dirty_impl) > 0,
        "Some elements marked as dirty during the forward method were not returned as output. The"
        " inputs that are modified inplace must all be outputs of the Function.");
  }

  return outputs;
}
|
|
|
|
// Wraps the raw outputs of a custom Function's forward so they participate in
// autograd. It first runs backward-mode processing (graph edges, view and
// dirty handling), then forward-mode processing (jvp tangents) on the
// resulting outputs. Returns the processed outputs.
optional_variable_list _wrap_outputs(
    const variable_list& input_vars,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<std::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata,
    const _jvp_fn_t& jvp_user_function,
    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
    const _view_as_self_fn_t& view_as_self_fn,
    bool pure_view) {
  // Pure views are limited to a 1-1 mapping; it is unclear whether an N-1 or
  // 1-N pure view is even possible.
  TORCH_CHECK(
      !pure_view || (input_vars.size() == 1 && raw_outputs.size() == 1),
      "Pure view custom Function can only have one input Tensor and one output Tensor. Open an issue if you need to support more.");

  // Map each input's TensorImpl to its position (first occurrence wins).
  std::unordered_map<at::TensorImpl*, size_t> impl_to_input_idx;
  impl_to_input_idx.reserve(input_vars.size());
  size_t position = 0;
  for (const auto& input : input_vars) {
    impl_to_input_idx.emplace(input.unsafeGetTensorImpl(), position);
    ++position;
  }

  auto wrapped = _process_backward_mode_ad(
      impl_to_input_idx,
      non_differentiable,
      dirty_inputs,
      raw_outputs,
      cdata,
      to_save_if_setup_context,
      view_as_self_fn,
      pure_view);

  // Forward-mode processing must come after the backward pass above, because
  // the computations it performs are expected to be tracked by backward-mode
  // autograd.
  _process_forward_mode_AD(
      input_vars,
      std::move(impl_to_input_idx),
      raw_outputs,
      wrapped,
      non_differentiable,
      dirty_inputs,
      jvp_user_function);

  return wrapped;
}
|
|
|
|
void check_variable_result(
|
|
const at::TensorBase& original,
|
|
const at::TensorBase& result,
|
|
const std::string& hook_name) {
|
|
TORCH_CHECK(
|
|
original.options().type_equal(result.options()),
|
|
"hook '",
|
|
hook_name,
|
|
"' has changed the type of value (was ",
|
|
original.toString(),
|
|
" got ",
|
|
result.toString(),
|
|
")");
|
|
|
|
TORCH_CHECK(
|
|
original.is_cuda() == result.is_cuda(),
|
|
"hook '",
|
|
hook_name,
|
|
"' has changed the type of value (was ",
|
|
original.is_cuda() ? "CUDA tensor" : "CPU tensor",
|
|
" got ",
|
|
result.is_cuda() ? "CUDA tensor" : "CPU tensor",
|
|
")");
|
|
|
|
TORCH_CHECK(
|
|
original.sym_sizes().vec() == result.sym_sizes().vec(),
|
|
"hook '",
|
|
hook_name,
|
|
"' has changed the size of value");
|
|
}
|
|
|
|
// Rebuilds a context from packed arguments. Fields are unpacked sequentially,
// so the order below must match the order in which they were packed (the
// packing side is not shown in this file section).
AutogradContext::AutogradContext(PackedArgs& packed_args) {
  auto& args = packed_args;
  saved_data = args.unpack_saved_data();
  saved_variables_override_ = args.unpack<variable_list>();
  // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
  materialize_grads_ = args.unpack<bool>();
  // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
  has_freed_buffers_ = args.unpack<bool>();
  needs_input_grad_override_ = args.unpack<std::vector<bool>>();
}
|
|
|
|
// Stages `to_save` for saving. The list replaces any previously staged
// tensors; save_variables() later turns them into SavedVariables.
void AutogradContext::save_for_backward(variable_list to_save) {
  to_save_.swap(to_save);
}
|
|
|
|
// The logic for handling saved variables here is the same as
|
|
// python_function.cpp See _save_variables() and unpack_saved_variables()
|
|
void AutogradContext::save_variables() {
|
|
saved_variables_.clear();
|
|
auto ptr = grad_fn_.lock();
|
|
|
|
for (const auto& var : to_save_) {
|
|
// Allow empty variables to be saved
|
|
if (var.defined()) {
|
|
bool is_output = var.grad_fn().get() == ptr.get();
|
|
saved_variables_.emplace_back(var, is_output);
|
|
} else {
|
|
saved_variables_.emplace_back();
|
|
}
|
|
}
|
|
to_save_.clear();
|
|
}
|
|
|
|
// Returns the tensors saved for backward. If an override list is present, it
// is returned as-is. Otherwise every SavedVariable is unpacked against this
// context's grad_fn. Throws ERR_BACKWARD_TWICE once buffers have been freed.
variable_list AutogradContext::get_saved_variables() const {
  TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE);
  if (saved_variables_override_) {
    return *saved_variables_override_;
  }

  const auto node = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(node);

  variable_list unpacked;
  unpacked.reserve(saved_variables_.size());
  for (const auto& saved : saved_variables_) {
    unpacked.push_back(saved.unpack(node));
  }
  return unpacked;
}
|
|
|
|
// Reports whether the gradient for input `output_edge_index` must be
// computed. An override list, when present, takes precedence
// (bounds-checked). Otherwise the grad_fn node is asked.
bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
  if (needs_input_grad_override_) {
    return needs_input_grad_override_->at(output_edge_index);
  }
  const auto node = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(node);
  return node->task_should_compute_output(output_edge_index);
}
|
|
|
|
bool AutogradContext::needs_input_grad(
|
|
std::initializer_list<IndexRange> idxs) const {
|
|
if (needs_input_grad_override_.has_value()) {
|
|
return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
|
|
bool result = false;
|
|
for (const auto i : c10::irange(range.first, range.second)) {
|
|
result |= needs_input_grad_override_.value().at(i);
|
|
}
|
|
return result;
|
|
});
|
|
}
|
|
auto ptr = grad_fn_.lock();
|
|
TORCH_INTERNAL_ASSERT(ptr);
|
|
return ptr->task_should_compute_output(idxs);
|
|
}
|
|
|
|
void AutogradContext::mark_dirty(const variable_list& inputs) {
|
|
dirty_inputs_.clear();
|
|
dirty_inputs_.reserve(inputs.size());
|
|
for (auto& var : inputs) {
|
|
dirty_inputs_.insert(var.unsafeGetTensorImpl());
|
|
}
|
|
}
|
|
|
|
void AutogradContext::mark_non_differentiable(const variable_list& outputs) {
|
|
non_differentiable_.clear();
|
|
non_differentiable_.reserve(outputs.size());
|
|
for (auto& var : outputs) {
|
|
non_differentiable_.insert(var.unsafeGetTensorImpl());
|
|
}
|
|
}
|
|
|
|
void AutogradContext::set_materialize_grads(bool value) {
|
|
materialize_grads_ = value;
|
|
}
|
|
|
|
const std::unordered_set<at::TensorImpl*>& AutogradContext::get_and_bump_dirty()
|
|
const {
|
|
for (auto& var : dirty_inputs_) {
|
|
var->bump_version();
|
|
}
|
|
return dirty_inputs_;
|
|
}
|
|
|
|
const std::unordered_set<at::TensorImpl*>& AutogradContext::
|
|
get_non_differentiable() const {
|
|
return non_differentiable_;
|
|
}
|
|
} // namespace torch::autograd
|