Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-27 00:54:52 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49734
RFC: https://github.com/pytorch/rfcs/pull/11

This PR adds the basic logic to handle forward grad as dual Tensors. It contains the following:
- Mechanism to save dual state on a Tensor and clear it up when the dual level ends
- C++ and Python user-facing API
- Updated view system that is able to track both forward and backward views

The current PR has the following limitations:
- Extensive tests are in the next PR in the stack, as formulas are needed to write full tests.
- Only the manual formulas have been audited; no other formula is actually implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. This was discussed and agreed that it is not needed for the first version of this PR.
- We can save one ViewInfo creation when both the forward and backward views have the same base. This can be done by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. It is left out to keep this PR concise.
- We can skip tracking forward views if the base has a forward grad. This can be done by adding extra logic in the `as_view` method. It is left out to keep this PR concise.

Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo to hold view information shared between forward and backward, updates the differentiable view meta to use it, and updates the `as_view` function to handle both forward and backward views.
- New forward grad class that handles storing gradients and tracking at each level: [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: These files also contain the new flag to globally disable forward AD, which allows us to reduce performance issues while this is in development.
- Lowest level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal, which needs to be a differentiable function (and so lives in native_functions.yaml): [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991), [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- C++ API: [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- Python binding: [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- Python API: [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8)
- C++ and Python printing: [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utilities for formulas and updated manual functions to respect the new view system as well as forward grad: [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5) and the [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable saves the forward grad properly: [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D25678797

Pulled By: albanD

fbshipit-source-id: 3d58550c11b5f58b9b73fd30596d042b857fb9dd
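The file shown below, custom_function.cpp, implements the `AutogradContext` and output-wrapping machinery that this PR updates for the new view system. As a reminder of how that machinery is used, here is a minimal sketch of a custom C++ autograd Function built on the documented libtorch API; `MySquare`, the tensor shape, and the `main()` harness are illustrative choices, not code added by this PR.

```cpp
#include <torch/torch.h>
#include <iostream>

using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::variable_list;

// Hypothetical custom op: y = x * x, with a hand-written backward.
struct MySquare : public Function<MySquare> {
  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor x) {
    // Saved tensors go through SavedVariable (see AutogradContext::save_variables()).
    ctx->save_for_backward({x});
    return x * x;
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
    auto x = ctx->get_saved_variables()[0];
    return {grad_output[0] * 2 * x};  // d(x*x)/dx = 2x
  }
};

int main() {
  auto x = torch::randn({3}, torch::requires_grad());
  // Function<T>::apply runs forward() and then _wrap_outputs() (defined in
  // custom_function.cpp below) to hook the output into the autograd graph.
  auto y = MySquare::apply(x);
  y.sum().backward();
  std::cout << x.grad() << std::endl;
}
```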
252 lines
9.2 KiB
C++
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/autograd.h>

namespace torch { namespace autograd {

VariableInfo::VariableInfo(const Variable& var)
  : layout(var.layout())
  , device(var.device())
  , scalar_type(var.scalar_type())
  , size(var.sizes().vec())
  , requires_grad(var.requires_grad()) {
}

Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
  return at::zeros(size,
    at::TensorOptions(scalar_type).device(device).layout(layout));
}

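// _wrap_outputs wires the raw outputs of a custom Function's forward() into
// the autograd graph: non-differentiable outputs are detached, inputs that
// were modified in place ("dirty") get their history rebased onto `cdata`,
// inputs returned unchanged are re-wrapped as views so a new grad_fn can be
// attached, and other differentiable outputs get `cdata` as their grad_fn.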
variable_list _wrap_outputs(const variable_list &input_vars,
  const std::unordered_set<at::TensorImpl*> &non_differentiable,
  const std::unordered_set<at::TensorImpl*> &dirty_inputs,
  const at::ArrayRef<Variable> raw_outputs,
  const std::shared_ptr<Node> &cdata) {

  std::unordered_set<at::TensorImpl*> inputs;
  inputs.reserve(input_vars.size());
  for (auto& var : input_vars) {
    inputs.emplace(var.unsafeGetTensorImpl());
  }

  int num_outputs = raw_outputs.size();

  // Sets the grad_fn and output_nr of an output Variable.
  auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
                         bool is_differentiable) {
    if (!is_differentiable) {
      if (!var.requires_grad()) {
        return;
      }
      // Return detached aliases of inputs, instead of changing their requires_grad
      // property.
      if (is_input) {
        var = var.detach();
      } else if (!var.is_view()) {
        var.detach_();
      }
      // If var is a view of one of the inputs of the custom autograd Function,
      // we don't detach it in a no_grad block. This is so that we can mimic the
      // behavior of returning a view from a no_grad block:
      //   x = torch.randn(3, requires_grad=True)
      //   with torch.no_grad():
      //       y = x.view(-1)
      // Here, `y` requires_grad (!).
    } else if (is_modified) {
      if (var.is_leaf() && var.requires_grad()) {
        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
      }
      // No need to mark as modified Tensors that are not inputs.
      if (!is_input) {
        TORCH_WARN("Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
                   " is no need to pass it to mark_dirty().");
      }
      // If the input is a view, the rebase will need to rewrite the graph and this only works if we have a single
      // output to this Function.
      TORCH_CHECK(!(var.is_view() && num_outputs > 1), "If your Function modifies inplace an input that is a view"
                  " of another Tensor, your Function cannot return more than one Tensor. This is not supported"
                  " by the current autograd engine. You should either make sure the input is not a view (using"
                  " .clone() for example) or make your Function only return one Tensor (potentially splitting"
                  " it into two Functions: one doing the inplace that returns a single Tensor and a second one"
                  " that does the other operations). You can ask on the forum https://discuss.pytorch.org/ if"
                  " you need help to do this change.");

      // If the input was modified, transplant the grad_fn in the graph:
      // grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
      var.mutable_grad().reset();
      impl::clear_hooks(var);
      if (auto grad_acc_fn = impl::try_get_grad_accumulator(var)) {
        auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
        grad_acc->variable.reset();
      }
      if (cdata) {
        impl::rebase_history(var, {cdata, output_nr});
      }
    } else if (is_input) {
      // An input has been returned, but it wasn't modified. Return it as a view
      // so that we can attach a new grad_fn to the Variable.
      // Run in no_grad mode to mimic the behavior of the forward.
      {
        AutoGradMode grad_mode(false);
        var = var.view_as(var);
      }
      impl::set_gradient_edge(var, {cdata, output_nr});
    } else if (cdata) {
      impl::set_gradient_edge(var, {cdata, output_nr});
    }
  };

  std::vector<torch::autograd::Variable> outputs;
  std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
  outputs.reserve(num_outputs);
  int num_diff_outputs = 0;

  for (auto i = 0; i < num_outputs; ++i) {
    Variable var = raw_outputs[i];

    auto out_tensor_impl = raw_outputs[i].unsafeGetTensorImpl();
    bool is_input = inputs.count(out_tensor_impl) > 0;
    bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
    bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0
                             && isDifferentiableType(var.scalar_type());

    if (cdata) {
      auto output_nr = cdata->add_input_metadata(var);
      AT_ASSERT(i == (int)output_nr);
    }
    set_history(var, i, is_input, is_modified, is_differentiable);

    // For deprecation cycle. Can be removed after 1.6. In the case where we detected a view
    // in no grad mode during the forward, only warn the user (do not change the flag if we
    // return an input that is a view as is).
    // See NOTE [ View + Inplace detection ] for why we replace everything by a warning.
    if (!(is_input && is_modified) && var.is_view()) {
      // NB: is_view() ==> get_autograd_meta()
      auto diff_view_meta = static_cast<DifferentiableViewMeta*>(impl::get_autograd_meta(var));
      diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
    }

    if (is_differentiable) {
      ++num_diff_outputs;
    }

    outputs_impl.insert(out_tensor_impl);
    outputs.emplace_back(var);
  }

  // If multiple differentiable outputs are returned, we do not allow views to be modified inplace
  // See NOTE [ View + Inplace detection ] for more details
  if (num_diff_outputs > 1) {
    for (auto& var: outputs) {
      if (var.is_view()) {
        // NB: is_view() ==> get_autograd_meta()
        auto diff_view_meta = static_cast<DifferentiableViewMeta*>(impl::get_autograd_meta(var));
        diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
      }
    }
  }

  // All the modified Tensors must be returned as is for the rewrite to be valid.
  for (auto& dirty_input : dirty_inputs) {
    TORCH_CHECK(outputs_impl.count(dirty_input) > 0,
                "Some elements marked as dirty during the forward method were not returned as output. The"
                " inputs that are modified inplace must all be outputs of the Function.");
  }

  return outputs;
}

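// Verifies that a hook did not change the type (dtype/layout), the device
// kind (CPU vs CUDA), or the size of the value it was given; throws otherwise.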
void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) {
  if (!original.options().type_equal(result.options())) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the type of value (";
    ss << "was " << original.toString() << " got ";
    ss << result.toString() << ")";
    throw std::runtime_error(ss.str());
  }

  if (original.is_cuda() != result.is_cuda()) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the type of value";
    if (original.is_cuda()) {
      ss << " (was CUDA tensor got CPU tensor)";
    } else {
      ss << " (was CPU tensor got CUDA tensor)";
    }
    throw std::runtime_error(ss.str());
  }

  if (original.sizes().vec() != result.sizes().vec()) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the size of value";
    throw std::runtime_error(ss.str());
  }
}

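// AutogradContext: the `ctx` object handed to custom C++ autograd Functions.
// The methods below back the user-facing ctx->save_for_backward(),
// ctx->mark_dirty() and ctx->mark_non_differentiable() calls.
//
// Illustrative sketch only (not part of this file); `ScaleInPlace` is a
// hypothetical Function that modifies its input in place:
//
//   struct ScaleInPlace : public Function<ScaleInPlace> {
//     static Variable forward(AutogradContext* ctx, Variable x, double s) {
//       ctx->saved_data["s"] = s;
//       ctx->mark_dirty({x});   // x is modified in place...
//       return x.mul_(s);       // ...and must also be returned
//                               // (enforced by _wrap_outputs above).
//     }
//     static variable_list backward(AutogradContext* ctx, variable_list grad) {
//       auto s = ctx->saved_data["s"].toDouble();
//       return {grad[0] * s, Variable()};  // no grad for the double argument
//     }
//   };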
void AutogradContext::save_for_backward(variable_list to_save) {
  to_save_ = std::move(to_save);
}

// The logic for handling saved variables here is the same as python_function.cpp
// See _save_variables() and unpack_saved_variables()
void AutogradContext::save_variables() {
  saved_variables_.clear();
  auto ptr = grad_fn_.lock();

  for (const auto& var : to_save_) {
    // Allow empty variables to be saved
    if (var.defined()) {
      bool is_output = var.grad_fn().get() == ptr.get();
      saved_variables_.emplace_back(var, is_output);
    } else {
      saved_variables_.emplace_back();
    }
  }
  to_save_.clear();
}

variable_list AutogradContext::get_saved_variables() const {
  TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE);
  variable_list saved;
  saved.reserve(saved_variables_.size());
  auto ptr = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(ptr);
  for (auto& var : saved_variables_) {
    saved.push_back(var.unpack(ptr));
  }
  return saved;
}

void AutogradContext::mark_dirty(const variable_list &inputs) {
  dirty_inputs_.clear();
  dirty_inputs_.reserve(inputs.size());
  for(auto& var : inputs) {
    dirty_inputs_.insert(var.unsafeGetTensorImpl());
  }
}

void AutogradContext::mark_non_differentiable(const variable_list &outputs) {
  non_differentiable_.clear();
  non_differentiable_.reserve(outputs.size());
  for(auto& var : outputs) {
    non_differentiable_.insert(var.unsafeGetTensorImpl());
  }
}

void AutogradContext::set_materialize_grads(bool value) {
  materialize_grads_ = value;
}

const std::unordered_set<at::TensorImpl*>& AutogradContext::get_and_bump_dirty() const {
  for (auto& var : dirty_inputs_) {
    var->bump_version();
  }
  return dirty_inputs_;
}

const std::unordered_set<at::TensorImpl*>& AutogradContext::get_non_differentiable() const {
  return non_differentiable_;
}

}} // namespace torch::autograd