Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
* Improve Variable interface
* Address comments from @apaszke and @colesbury
* string ::operator= is not noexcept
* Remove ir.h from tracer_state.h to improve build times
* Make Variable a struct and pack SavedVariable fields
* Implement as_variable_ref
* grad_fn_ptr() -> grad_fn_unsafe()
* Reduce hackiness of set_type hack
* Include variable.h and edge.h in tracer_state.h because it uses them
* class Variable -> struct Variable because Windows cant even
* Make Variable::output_nr uint32_t instead of int
* Add comment about tracing state
* Replaced more static_cast<Variable&> and improve docs
* Remove SavedVariable destructor and construct members in init list
* Clarify docs for Variable
* Variable::set_version -> set_version_counter
93 lines · 2.5 KiB · C++
#pragma once

#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/variable.h"

#include <atomic>
#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch { namespace jit {
struct Graph;
struct Value;
struct VariableFlags;
}} // namespace torch::jit

namespace torch { namespace jit { namespace tracer {

using edge_list = std::vector<autograd::Edge>;
using variable_list = std::vector<autograd::Variable>;

// TracingState tracks the necessary state when we are tracing the execution of
// autograd code; most importantly, it holds a reference to the actual IR
// graph which we are recording the trace to.
//
// The liveness of a TracingState is expected to be a superset of the region
// of code being traced; in particular, Variables do not keep a TracingState
// live. Instead, they hold weak pointers to TracingState, to prevent leaks
// from arising when a variable that participated in a trace outlives the
// actual trace itself.
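//
// For illustration, promoting such a weak pointer before use might look like
// the following (a minimal sketch; the helper name `with_trace` and the
// recording step are hypothetical, not part of this header):
//
//   void with_trace(const std::weak_ptr<TracingState>& weak_state) {
//     if (auto state = weak_state.lock()) {  // trace is still alive
//       auto guard = state->lock();          // serialize access to the state
//       // ... record into state->graph ...
//     }
//     // otherwise the trace has expired and there is nothing to record into
//   }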

using io_variable_flags_list = std::vector<
    std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>>;

struct TracingState : public std::enable_shared_from_this<TracingState> {
  explicit TracingState(size_t num_stages);
  ~TracingState();

  std::shared_ptr<Graph> graph;
  bool active;

  // Used to free the Graph as soon as we know this trace will fail
  size_t num_stages;
  std::atomic<size_t> eval_count;

  // void* is an unsafe TH. NON-OWNING, so it might get invalidated.
  // TODO: Perhaps, turn this into an owning reference. The buffers
  // are persistent, so this won't lead to a leak.
  std::unordered_map<void*, Value*> buffer_map;
  // A pair of (input_flags, output_flags) for each stage
  io_variable_flags_list var_flags;
  std::vector<edge_list> output_edges;

  std::mutex mutex;
  variable_list inputs; // Used only for the duration of first stage

  std::unique_lock<std::mutex> lock() {
    return std::unique_lock<std::mutex>(mutex);
  }

  bool is_expired() const noexcept {
    return !graph;
  }

  bool is_complete() const;
  void push_scope(const std::string& scope_name);
  void pop_scope();
};
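
// Usage sketch (the caller below is hypothetical, not part of this header):
// callers serialize access through lock() and stop recording once the Graph
// has been released, which is what is_expired() reports.
//
//   void record_step(TracingState& state) {
//     auto guard = state.lock();  // holds state.mutex for the whole step
//     if (state.is_expired()) {
//       return;                   // graph already freed; trace abandoned
//     }
//     // ... append nodes to *state.graph ...
//   }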

struct ValueTracingStateElem {
  std::weak_ptr<TracingState> state;
  // it's only valid to use this field if !state.expired()
  Value* trace = nullptr;

  void reset() {
    state.reset();
    trace = nullptr;
  }
};
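
// For example, reading the cached Value is only safe behind an expiry check
// (a minimal sketch; the accessor name is hypothetical):
//
//   Value* trace_for(const ValueTracingStateElem& elem) {
//     if (!elem.state.expired()) {
//       return elem.trace;  // owning TracingState is alive, pointer is valid
//     }
//     return nullptr;       // trace has expired; elem.trace is stale
//   }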

using ValueTracingState = std::list<ValueTracingStateElem>;

struct FunctionTracingState {
  bool in_eval_subgraph = false;
};

}}} // namespace torch::jit::tracer