mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Improve Variable interface * Address comments from @apaszke and @colesbury * string ::operator= is not noexcept * Remove ir.h from tracer_state.h to improve build times * Make Variable a struct and pack SavedVariable fields * Implement as_variable_ref * grad_fn_ptr() -> grad_fn_unsafe() * Reduce hackiness of set_type hack * Include variable.h and edge.h in tracer_state.h because it uses them * class Variable -> struct Variable because Windows cant even * Make Variable::output_nr uint32_t instead of int * Add comment about tracing state * Replaced more static_cast<Variable&> and improve docs * Remove SavedVariable destructor and construct members in init list * Clarify docs for Variable * Variable::set_version -> set_version_counter
39 lines
1.2 KiB
C++
39 lines
1.2 KiB
C++
#pragma once
|
|
|
|
#include <atomic>
|
|
#include <cstdint>
|
|
#include <memory>
|
|
|
|
// Every Variable has a version counter. Version counters are incremented
|
|
// whenever the data or shape of a tensor changes through Variable operations.
|
|
// These are typically in-place operations. Version counters are used to
|
|
// detect modifications to saved variables which would result in incorrect
|
|
// gradient calculations. Version counters may be shared between Variables:
|
|
//
|
|
// 1. A view shares the version counter of the base Variable,
|
|
// 2. Detached variables share the version counter of the source,
|
|
// 3. Unpacked saved variables share the version counter of the source.
|
|
|
|
namespace torch { namespace autograd {
|
|
|
|
/// Handle to a version counter that may be shared by several Variables.
///
/// The counter lives behind a std::shared_ptr. Copying a VariableVersion
/// therefore copies the pointer, not the count: every copy observes (and can
/// advance) the same counter. Reads and increments of the counter itself are
/// performed on a std::atomic.
struct VariableVersion {
 private:
  using Counter = std::atomic<uint32_t>;

 public:
  // NOTE: As of C++11 and 14, default-constructing a std::atomic variable
  // leaves it in a persistently undefined state, which is why the counter is
  // always created with an explicit starting value here. See
  // https://cplusplus.github.io/LWG/issue2334.
  VariableVersion(uint32_t version = 0)
      : version_block_(std::make_shared<Counter>(version)) {}

  /// Advances the shared counter by one (seq_cst atomic increment).
  void bump() noexcept {
    ++(*version_block_);
  }

  /// Returns the shared counter's current value (seq_cst atomic load).
  uint32_t current_version() const noexcept {
    const Counter& counter = *version_block_;
    return counter;
  }

 private:
  // Shared between all copies of this handle; never null after construction.
  std::shared_ptr<Counter> version_block_;
};
|
|
}} // namespace torch::autograd
|