Files
pytorch/torch/csrc/jit/tracer_state.cpp
Peter Goldsborough 2d5fbe6e0d Improve Variable interface (#5127)
* Improve Variable interface

* Address comments from @apaszke and @colesbury

* string ::operator= is not noexcept

* Remove ir.h from tracer_state.h to improve build times

* Make Variable a struct and pack SavedVariable fields

* Implement as_variable_ref

* grad_fn_ptr() -> grad_fn_unsafe()

* Reduce hackiness of set_type hack

* Include variable.h and edge.h in tracer_state.h because it uses them

* class Variable -> struct Variable because Windows can't even

* Make Variable::output_nr uint32_t instead of int

* Add comment about tracing state

* Replace more static_cast<Variable&> and improve docs

* Remove SavedVariable destructor and construct members in init list

* Clarify docs for Variable

* Variable::set_version -> set_version_counter
2018-02-12 23:26:26 -05:00

39 lines
921 B
C++

#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/ir.h"
#include <atomic>
#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch { namespace jit { namespace tracer {
TracingState::TracingState(size_t num_stages)
: graph(new Graph()),
active(false),
num_stages(num_stages),
eval_count(0),
var_flags(num_stages),
output_edges(num_stages) {}
TracingState::~TracingState() = default;
bool TracingState::is_complete() const {
return !is_expired() && graph->stage() == num_stages - 1;
}
void TracingState::push_scope(const std::string& scope_name) {
graph->push_scope(scope_name);
}
void TracingState::pop_scope() {
graph->pop_scope();
}
}}} // namespace torch::jit::tracer