mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: More clang-tidy cleanups in `torch/csrc`. This time: 1. `hicpp-use-equals-default` recommends `= default` instead of `{}` for constructors/destructors. This is better practice because it expresses intent more clearly (https://stackoverflow.com/questions/6502828/what-does-default-mean-after-a-class-function-declaration). 2. `readability-inconsistent-declaration-parameter-name` enforces that parameter names in a declaration match those in the definition. This is generally useful and can prevent confusion and bugs. Also updated my script a little bit. apaszke ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/9737 Differential Revision: D9069069 Pulled By: goldsborough fbshipit-source-id: f7b3f3a4eb4c9fadc30425a153566d3b613a41ae
55 lines
1.6 KiB
C++
55 lines
1.6 KiB
C++
#pragma once
|
|
|
|
#include "torch/csrc/WindowsTorchApiMacro.h"
|
|
#include "torch/csrc/autograd/variable_version.h"
|
|
|
|
#include <ATen/ATen.h>
|
|
|
|
#include <cstdint>
|
|
#include <list>
|
|
#include <memory>
|
|
|
|
namespace torch { namespace autograd {
|
|
|
|
struct Variable;
|
|
struct Function;
|
|
|
|
TORCH_API extern const char* ERR_BACKWARD_TWICE;
|
|
|
|
/// A snapshot of a variable at a certain version. A `SavedVariable` stores
|
|
/// enough information to reconstruct a variable from a certain point in time.
|
|
class TORCH_API SavedVariable {
|
|
public:
|
|
SavedVariable() = default;
|
|
SavedVariable(const Variable& variable, bool is_output);
|
|
SavedVariable(SavedVariable&&) = default;
|
|
SavedVariable& operator=(SavedVariable&&) = default;
|
|
|
|
/// Reconstructs the saved variable. Pass `saved_for` as the gradient
|
|
/// function if constructing the `SavedVariable` with it would have caused a
|
|
/// circular reference.
|
|
Variable unpack(std::shared_ptr<Function> saved_for = nullptr) const;
|
|
|
|
void reset_data() {
|
|
return data_.reset();
|
|
}
|
|
|
|
private:
|
|
at::Tensor data_;
|
|
|
|
// The gradient function associated with this node. If has_grad_fn
|
|
// is false, then this is a leaf node. Note that the grad_fn is not saved if
|
|
// it would create a circular reference. In that case, the grad_fn must be
|
|
// passed in to the unpack function when reconstructing the Variable.
|
|
std::shared_ptr<Function> grad_fn_;
|
|
std::weak_ptr<Function> grad_accumulator_;
|
|
VariableVersion version_counter_;
|
|
|
|
uint32_t saved_version_ = 0;
|
|
uint32_t output_nr_ = 0;
|
|
bool was_default_constructed_ = true;
|
|
bool requires_grad_ = false;
|
|
bool has_grad_fn_ = false;
|
|
};
|
|
}} // namespace torch::autograd
|