Files
pytorch/torch/csrc/autograd/variable_version.h
Peter Goldsborough 2d5fbe6e0d Improve Variable interface (#5127)
* Improve Variable interface

* Address comments from @apaszke and @colesbury

* string ::operator= is not noexcept

* Remove ir.h from tracer_state.h to improve build times

* Make Variable a struct and pack SavedVariable fields

* Implement as_variable_ref

* grad_fn_ptr() -> grad_fn_unsafe()

* Reduce hackiness of set_type hack

* Include variable.h and edge.h in tracer_state.h because it uses them

* class Variable -> struct Variable because Windows cant even

* Make Variable::output_nr uint32_t instead of int

* Add comment about tracing state

* Replaced more static_cast<Variable&> and improve docs

* Remove SavedVariable destructor and construct members in init list

* Clarify docs for Variable

* Variable::set_version -> set_version_counter
2018-02-12 23:26:26 -05:00


#pragma once
#include <atomic>
#include <cstdint>
#include <memory>
// Every Variable has a version counter. Version counters are incremented
// whenever the data or shape of a tensor changes through Variable operations.
// These are typically in-place operations. Version counters are used to
// detect modifications to saved variables which would result in incorrect
// gradient calculations. Version counters may be shared between Variables:
//
// 1. A view shares the version counter of the base Variable,
// 2. Detached variables share the version counter of the source,
// 3. Unpacked saved variables share the version counter of the source.
namespace torch { namespace autograd {

struct VariableVersion {
 public:
  // NOTE: As of C++11 and 14, default-constructing a std::atomic variable
  // leaves it in a persistently undefined state. See
  // https://cplusplus.github.io/LWG/issue2334.
  VariableVersion(uint32_t version = 0)
      : version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {}

  void bump() noexcept {
    version_block_->fetch_add(1);
  }

  uint32_t current_version() const noexcept {
    return version_block_->load();
  }

 private:
  std::shared_ptr<std::atomic<uint32_t>> version_block_;
};

}} // namespace torch::autograd