mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Improve Variable interface * Address comments from @apaszke and @colesbury * string ::operator= is not noexcept * Remove ir.h from tracer_state.h to improve build times * Make Variable a struct and pack SavedVariable fields * Implement as_variable_ref * grad_fn_ptr() -> grad_fn_unsafe() * Reduce hackiness of set_type hack * Include variable.h and edge.h in tracer_state.h because it uses them * class Variable -> struct Variable because Windows can't even * Make Variable::output_nr uint32_t instead of int * Add comment about tracing state * Replace more static_cast<Variable&> and improve docs * Remove SavedVariable destructor and construct members in init list * Clarify docs for Variable * Variable::set_version -> set_version_counter
39 lines
921 B
C++
39 lines
921 B
C++
#include "torch/csrc/jit/tracer_state.h"
|
|
#include "torch/csrc/autograd/edge.h"
|
|
#include "torch/csrc/autograd/variable.h"
|
|
#include "torch/csrc/jit/ir.h"
|
|
|
|
#include <atomic>
|
|
#include <cstdint>
|
|
#include <list>
|
|
#include <memory>
|
|
#include <mutex>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace torch { namespace jit { namespace tracer {
|
|
TracingState::TracingState(size_t num_stages)
|
|
: graph(new Graph()),
|
|
active(false),
|
|
num_stages(num_stages),
|
|
eval_count(0),
|
|
var_flags(num_stages),
|
|
output_edges(num_stages) {}
|
|
|
|
TracingState::~TracingState() = default;
|
|
|
|
bool TracingState::is_complete() const {
|
|
return !is_expired() && graph->stage() == num_stages - 1;
|
|
}
|
|
|
|
void TracingState::push_scope(const std::string& scope_name) {
|
|
graph->push_scope(scope_name);
|
|
}
|
|
|
|
void TracingState::pop_scope() {
|
|
graph->pop_scope();
|
|
}
|
|
}}} // namespace torch::jit::tracer
|