#include <torch/csrc/jit/runtime/interpreter.h>

#include <ATen/Parallel.h>
#include <ATen/core/ivalue.h>
#include <ATen/record_function.h>
#include <c10/core/thread_pool.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/interpreter/code_impl.h>
#include <torch/csrc/jit/runtime/interpreter/frame.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/script_profile.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <string>

#ifdef USE_RPC
#include <torch/csrc/distributed/autograd/context/container.h>
using torch::distributed::autograd::DistAutogradContainer;
#endif

#include <exception>
#include <memory>
#include <mutex>
#include <ostream>
#include <stdexcept>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// clang-format off
C10_DEFINE_bool(
    torch_jit_enable_rethrow_caught_exception,
    false,
    "enable rethrowing caught exception")

C10_DEFINE_bool(
    torch_jit_enable_expanded_stacks,
    false,
    "When true we will attempt to pre-expand node stacks and cache expanded stacks.")

C10_DEFINE_bool(
    torch_jit_expanded_stacks_mangled,
    false,
    "When true pre-expanded stacks will use mangled names.")

namespace torch::jit {

using CodeImpl = interpreter::CodeImpl;

// Before we translate to interpreter instructions, we do
// some preprocessing of the graph to turn it into a form that is closer
// to what the instructions will look like.
// In particular we:
// * Compute whether an input to a node is the last use, so we can issue MOVE
//   rather than LOAD instructions.
// * Insert Drop nodes for any node that is unused, to create a dummy use
//   that will cause the interpreter to free the node.
//   A drop node just pops its input off the stack to ensure the interpreter
//   releases references to nodes that are never used. Drop nodes are also
//   inserted when the last use of a node is in some conditionally run control
//   flow (e.g. one side of an If) and the interpreter must free the node only
//   after the control flow has reconverged.
// Outputs are:
// * graph - the post-processed copy of g
// * move_flags[n] - a list of booleans, one for each input,
//   indicating whether this is the last use of the value. The interpreter
//   should generate a move rather than a copy in this case.

TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
  if (!t.defined()) {
    return TensorType::get()->withUndefined();
  }
  auto r = TensorType::create(t);
  if (!at::GradMode::is_enabled()) {
    return r->withRequiresGrad(false);
  }
  return r;
}

namespace {
inline int64_t getDistAutogradContextId() {
#ifdef USE_RPC
  return DistAutogradContainer::currentContextId();
#else
  return 0;
#endif
}
} // namespace

static thread_local InterpreterStateImpl* tls_int_state_ptr_ = nullptr;
struct TLSCurrentInterpreterGuard {
  TLSCurrentInterpreterGuard(InterpreterStateImpl* state)
      : prev_state_(tls_int_state_ptr_) {
    tls_int_state_ptr_ = state;
  }

  ~TLSCurrentInterpreterGuard() {
    tls_int_state_ptr_ = prev_state_;
  }

 private:
  InterpreterStateImpl* prev_state_;
};

bool in_torchscript_runtime() {
  return tls_int_state_ptr_ != nullptr;
}

// The interpreter state used to run a Code.
struct InterpreterStateImpl : c10::intrusive_ptr_target {
  InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
      : taskLauncher_(std::move(taskLauncher)) {
    enterFrame(code, 0);
  }

 private:
  using Frame = torch::jit::interpreter::Frame;
  struct WarnedNodes {
   public:
    // Inserts idx into warned_nodes_, returns a boolean indicating whether the
    // insertion actually happened (idx wasn't originally in the set).
    bool insert(int32_t idx) {
      std::unique_lock<std::mutex> lock(mutex_);
      return warned_nodes_.insert(idx).second;
    }

   private:
    std::mutex mutex_;
    std::unordered_set<int32_t> warned_nodes_;
  };

  WarnedNodes warned_nodes_;

  // if we need to suspend, where do we reset the stack?
  // answer: to where it was when we were called, not
  // including any inputs to this function
  int64_t stack_start_ = -1;
  c10::intrusive_ptr<Future> future_;
  TaskLauncher taskLauncher_;

  // This holds all the tensors for this interpreter run.
  // We don't bother minimizing the size of this vector, since the extra
  // memory used by the pointers in it will be small.
  // Instead we are very aggressive about releasing tensors when they become
  // dead to make sure memory management happens efficiently. We optimize for
  // the case where derivatives are run with retain_graph=False; in the case
  // where it is true, the interpreter and this array get copied. If this
  // ever becomes a bottleneck then we _should_ consider minimizing the total
  // number of registers.
  std::vector<IValue> registers;

  // A stack of objects that have been __enter__'d.
  std::vector<IValue> entered_objects;

  std::vector<Frame> frames;

  c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
    c10::raw::intrusive_ptr::incref(this);
    return c10::intrusive_ptr<InterpreterStateImpl>::reclaim(this);
  }

  void enterFrame(const Code& code, size_t base_pointer) {
    frames.emplace_back(Frame{code.pImpl, 0, base_pointer, std::nullopt});
    registers.resize(registers.size() + code.pImpl->register_size_);
  }

  void leaveFrame() {
    registers.resize(registers.size() - frames.back().function->register_size_);
    frames.pop_back();
  }

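  // Runs `f` on the current stack. When f.call enters a new interpreter frame
  // (via the callback below) it returns true; otherwise the callee already ran
  // inline on the stack. If `next` is set, the pc of the frame that issued the
  // call (one below the new top when a frame was pushed) is advanced so
  // execution resumes after the call instruction.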
  void callFunction(
      Function& f,
      Stack& stack,
      std::optional<size_t> bailOut = std::nullopt,
      bool next = true) {
    bool newFrame = f.call(stack, bailOut, [&](const Code& code) {
      enterFrame(code, stack.size() - code.num_inputs());
      checkAndStartRecordFunction(frames.back(), stack);
    });
    if (next) {
      (frames.rbegin() + (newFrame ? 1 : 0))->pc++;
    }
  }

  // relative to the end of the register list so that when we call
  // functions we are referring to the registers of the currently executing
  // function.
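  // For example, reg(1) refers to the most recently allocated register slot of
  // the innermost frame, i.e. *(registers.end() - 1).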
  IValue& reg(size_t reg) {
    return *(registers.end() - reg);
  }

  void dump(std::ostream& out, const Stack& stack) const {
    out << "Stack:\n";
    for (const auto& val : stack) {
      out << val;
      out << "\n";
    }
  }

  class StackSizeDidntChangeGuard {
   public:
    StackSizeDidntChangeGuard(const StackSizeDidntChangeGuard&) = delete;
    StackSizeDidntChangeGuard(StackSizeDidntChangeGuard&&) = delete;
    StackSizeDidntChangeGuard& operator=(const StackSizeDidntChangeGuard&) =
        delete;
    StackSizeDidntChangeGuard& operator=(StackSizeDidntChangeGuard&&) = delete;

    StackSizeDidntChangeGuard(
        const Frame& frame,
        const torch::jit::Stack& stack,
        const Instruction& inst)
        : frame_(frame), stack_(stack), instX_(inst.X) {
      // portable maybe_unused attribute.
      (void)frame_;
      (void)stack_;
      (void)instX_;
      (void)initialSize_;
    }

    void callAssert() const {
#ifndef NDEBUG
      frame_.function->assert_stack_size(instX_, initialSize_, stack_.size());
#endif
    }

   private:
    const Frame& frame_;
    const torch::jit::Stack& stack_;
    std::uint32_t instX_;
    std::size_t initialSize_{stack_.size()};
  };

  struct [[maybe_unused]] DoNothing {};

#if defined(__GNUC__) || defined(__clang__)
#define JIT_USE_COMPUTED_GOTO
#endif
  // Primitives for making interpreter internal state transitions.
  // We maintain two local variables as the internal interpreter state:
  // `frame` will be the current frame that the interpreter operates on.
  // `inst` will be the current instruction pointed to by the program counter.
  //
  // Instruction blocks should always be declared through the `INST` macro and
  // the instruction body should always start with an `instGuard()` declaration.
  // Also, blocks should be ended properly with either `INST_NEXT` (for going
  // to the next instruction), or `INST_DISPATCH` (for jumping to a computed
  // position using `instFetch`).
#if defined(JIT_USE_COMPUTED_GOTO)
#define INST(NAME) \
  NAME:            \
  label_##NAME
#define INST_DISPATCH goto* dispatch_table[inst.op]
#else
#define INST(NAME) NAME
#define INST_DISPATCH break
#endif
#define INST_NEXT      \
  inst = instFetch(1); \
  INST_DISPATCH
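  // For example, with computed gotos `case INST(LOAD):` expands to
  // `case LOAD: label_LOAD:` and `INST_NEXT` fetches the next instruction and
  // jumps straight to its label through `dispatch_table`; in the fallback
  // build the same macros reduce to a plain `case LOAD:` and a `break` back to
  // the enclosing switch.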

  template <bool EnableProfiling>
  bool runTemplate(Stack& stack) {
    // if we have never run before, then we might have to return the
    // stack when we suspend; record where it starts so we return the right
    // stack
    if (stack_start_ == -1) {
      TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
      stack_start_ = stack.size() - frames.back().function->n_inputs;
    } else {
      // during restarts, all of the stack is always our own, so we leave
      // nothing
      stack_start_ = 0;
    }

    TLSCurrentInterpreterGuard g(this);
    if (frames.back().pc == 0 && stack_start_ == 0) {
      checkAndStartRecordFunction(frames.back(), stack);
    }

#if defined(JIT_USE_COMPUTED_GOTO)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    static void* dispatch_table[] = {
#define DISPATCH_TABLE_ENTRY(op, _) &&label_##op,
        FORALL_OPCODES(DISPATCH_TABLE_ENTRY)
#undef DISPATCH_TABLE_ENTRY
    };
#endif

    try {
      while (true) {
        Frame& frame = frames.back();

        auto instFetch = [&](auto x) {
          return frame.function->instructions_[frame.pc += x];
        };

        auto instGuard = [&] {
          if constexpr (!EnableProfiling) {
            return DoNothing{};
          } else {
            return profiling::InstructionSpan{
                *frame.function->instructions_source()[frame.pc]};
          }
        };

        Instruction inst = instFetch(0);

        auto stackSizeAssertGuard = [&] {
          return StackSizeDidntChangeGuard{frame, stack, inst};
        };

        switch (inst.op) {
          case INST(ENTER): {
            [[maybe_unused]] auto _ = instGuard();
            const auto& obj = peek(stack, 0, 1);
            TORCH_INTERNAL_ASSERT(obj.isObject());
            entered_objects.push_back(obj);
          }
            INST_NEXT;
          case INST(EXIT): {
            [[maybe_unused]] auto _ = instGuard();
            auto obj = entered_objects.back().toObject();
            auto& f = obj->type()->getMethod("__exit__");
            push(stack, std::move(obj));
            entered_objects.pop_back();
            push(stack, IValue());
            push(stack, IValue());
            push(stack, IValue());
            callFunction(f, stack);
            continue;
          }
          case INST(OP): {
            [[maybe_unused]] auto _ = instGuard();
            auto stackSizeGuard = stackSizeAssertGuard();
            frame.function->operator_table_[inst.X](stack);
            stackSizeGuard.callAssert();
          }
            INST_NEXT;
          case INST(OPN): {
            [[maybe_unused]] auto _ = instGuard();
            stack.emplace_back(inst.N);
            auto stackSizeGuard = stackSizeAssertGuard();
            frame.function->operator_table_[inst.X](stack);
            stackSizeGuard.callAssert();
          }
            INST_NEXT;
          case INST(LOAD): {
            [[maybe_unused]] auto _ = instGuard();
            stack.emplace_back(reg(inst.X));
          }
            INST_NEXT;
          case INST(MOVE): {
            [[maybe_unused]] auto _ = instGuard();
            stack.emplace_back(std::move(reg(inst.X)));
          }
            INST_NEXT;
          case INST(STORE): {
            [[maybe_unused]] auto _ = instGuard();
            reg(inst.X) = pop(stack);
          }
            INST_NEXT;
          case INST(STOREN): {
            [[maybe_unused]] auto _ = instGuard();
            TORCH_INTERNAL_ASSERT(stack.size() >= inst.N);
            for (size_t i = inst.N; i > 0; --i) {
              reg(inst.X + i - 1) = pop(stack);
            }
          }
            INST_NEXT;
          case INST(DROP): {
            [[maybe_unused]] auto _ = instGuard();
            stack.pop_back();
          }
            INST_NEXT;
          case INST(DROPR): {
            [[maybe_unused]] auto _ = instGuard();
            reg(inst.X) = IValue();
          }
            INST_NEXT;
          case INST(LOADC): {
            [[maybe_unused]] auto _ = instGuard();
            stack.emplace_back(frame.function->constant_table_[inst.X]);
          }
            INST_NEXT;
          case INST(GET_ATTR): {
            [[maybe_unused]] auto _ = instGuard();
            const auto& userObj = stack.back().toObjectRef();
            stack.back() = userObj.getSlot(inst.X);
          }
            INST_NEXT;
          case INST(SET_ATTR): {
            [[maybe_unused]] auto _ = instGuard();
            auto v = pop(stack);
            auto& userObj = stack.back().toObjectRef();
            userObj.setSlot(inst.X, std::move(v));
            stack.pop_back();
          }
            INST_NEXT;
          case INST(JF): {
            [[maybe_unused]] auto _ = instGuard();
            if (pop(stack).toBool()) {
              inst = instFetch(1);
            } else {
              inst = instFetch(inst.X);
            }
          }
            INST_DISPATCH;
          case INST(JMP): {
            [[maybe_unused]] auto _ = instGuard();
            inst = instFetch(inst.X);
          }
            INST_DISPATCH;
          case INST(LOOP): {
            [[maybe_unused]] auto _ = instGuard();
            // stack: iteration_count, max_iter, cond, loop_carried_deps...
            auto fr = stack.end() - (inst.N + 1);
            int64_t trip_count = fr[0].toInt();
            int64_t max_trip_count = fr[1].toInt();
            bool cond = fr[2].toBool();
            if (trip_count < max_trip_count && cond) {
              fr[2] = trip_count;
              fr[0] = trip_count + 1;
              inst = instFetch(1);
            } else {
              size_t n_loop_carried = inst.N - 2;
              for (const auto i : c10::irange(n_loop_carried)) {
                fr[i] = std::move(fr[i + 3]);
              }
              drop(stack, 3); // iteration_count, max_iter, cond
              inst = instFetch(inst.X);
            }
          }
            INST_DISPATCH;
          case INST(CALL): {
            [[maybe_unused]] auto _ = instGuard();
            Function* fn = frame.function->function_table_[inst.X];
            callFunction(*fn, stack);
            continue;
          }
          case INST(INTERFACE_CALL): {
            [[maybe_unused]] auto _ = instGuard();
            // note the hash table lookup to find the function
            // this can be more optimized if necessary, caching parts
            // of the hashing computation or storing the offset when
            // the object is turned into an interface

            // consider passing
            // `frames.back().function->remaining_bailout_depth_` into
            // `get_executor().getPlanFor()` to propagate the caller's depth
            // restrictions onto children. While this strategy has the
            // potential to reduce the number of compilations for too-dynamic
            // callers, we might miss opportunities where a caller is dynamic
            // but a callee gets stable arguments.
            Function& function =
                peek(stack, 0, inst.N)
                    .toObject()
                    ->type()
                    ->getMethod(
                        frame.function->constant_table_[inst.X].toStringRef());
            callFunction(function, stack);
            continue;
          }
          case INST(RET): {
            if (frames.size() > 1) {
              leaveFrame();
              continue;
            }
            if (future_) {
              auto num_outputs = frames.back().function->n_outputs;
              if (num_outputs == 1) {
                future_->markCompleted(stack.back());
              } else {
                future_->markCompleted(
                    c10::ivalue::Tuple::create(jit::last(stack, num_outputs)));
              }
            }
            // destroy the last frame and call RecordFunction's end callbacks
            leaveFrame();
            return false;
          }
          case INST(WAIT): {
            [[maybe_unused]] auto _ = instGuard();
            auto future = stack.back().toFuture();
            if (!future->completed()) {
              getOrCreateFuture();

              // callback needs to be a struct rather than a lambda so that
              // we can move the stack to the other thread
              struct Callback {
                Callback(
                    c10::intrusive_ptr<InterpreterStateImpl> state,
                    Stack stack)
                    : stateImpl_(std::move(state)),
                      state_(stateImpl_),
                      stack_(std::move(stack)),
                      dist_autograd_context_id_(getDistAutogradContextId()) {
                  state_ = InterpreterState(stateImpl_);
                }
                void operator()(c10::ivalue::Future& /* unused */) {
                  stateImpl_->taskLauncher_(InterpreterContinuation(
                      state_,
                      std::move(stack_),
                      dist_autograd_context_id_,
                      std::move(tls_state_)));
                }

               private:
                c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
                InterpreterState state_;
                Stack stack_;
                int64_t dist_autograd_context_id_;
                // preserve the original ThreadLocalState
                at::ThreadLocalState tls_state_;
              };

              // we are suspending, so we reset the stack to where it was when
              // we were called, not including our inputs (everything at or
              // above stack_start_ moves to `copied`). If the stack started
              // empty except for the inputs (stack_start_ == 0), we can avoid
              // a true copy by swapping, which leaves the original stack
              // empty.
              Stack copied;
              if (stack_start_ == 0) {
                copied.swap(stack);
              } else {
                copied.insert(
                    copied.begin(),
                    std::make_move_iterator(stack.begin() + stack_start_),
                    std::make_move_iterator(stack.end()));
                stack.resize(stack_start_);
              }
              // save pc into the frame so we continue here when restored
              future->addCallback(
                  Callback(intrusive_from_this(), std::move(copied)));

              return true;
            }
            stack.pop_back();
            stack.emplace_back(future->value());
          }
            INST_NEXT;
          case INST(PROFILE_OP): {
            [[maybe_unused]] auto _ = instGuard();
            auto& frame_id_ref = frame.id;
            if (!frame_id_ref.has_value()) {
              frame_id_ref = Frame::genId();
            }
            const auto& callback =
                frame.function->profile_function_table_[inst.X];
            push(stack, c10::IValue{static_cast<int64_t>(*frame_id_ref)});
            callback(stack);
          }
            INST_NEXT;
          case INST(FAIL_GUARD): {
            [[maybe_unused]] auto _ = instGuard();
            // patch FAIL_GUARD back to GUARD
            GRAPH_DEBUG(
                "Bailout ", inst.X, " triggered via bailout_requests_!");
            frame.function->instructions_[frame.pc].op = GUARD;
            push(stack, false);
          }
            INST_NEXT;
          case INST(TYPECHECK): {
            [[maybe_unused]] auto _ = instGuard();
            unsigned num_inputs = inst.N, i = 0;
            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
            // Check every input's shape against profiled (expected) shape.
            for (i = 0; i < num_inputs; i++) {
              auto& input = peek(stack, i, num_inputs);
              auto& t = input.toTensor();
              const TypePtr& expected = frame.function->type_table_[inst.X + i];
              auto* expected_type = expected->castRaw<TensorType>();
              if (t.defined() && !expected_type->matchTensor(t)) {
                push(stack, false);
                break;
              }
            }
            if (i == num_inputs) {
              push(stack, true);
            }
          }
            INST_NEXT;
          case INST(GUARD): {
            [[maybe_unused]] auto _ = instGuard();
            if (!stack.back().isTensor()) {
              // stack.back() is an Uninitialized IValue and this is a guard
              // on a block output. Uninitialized IValues are never used
              // so it's safe to pass this guard check
              push(stack, true);
            } else {
              auto& t = stack.back().toTensor();
              const TypePtr& expected = frame.function->type_table_[inst.X];
              auto* expected_type = expected->castRaw<TensorType>();
              if (t.defined() &&
                  !frames.back().symbols2dims.bindSymbolicShapes(
                      t.sizes(), expected_type->symbolic_sizes())) {
                push(stack, false);
              } else {
                push(stack, expected_type->matchTensor(t));
              }
            }
          }
            INST_NEXT;
          case INST(TAIL_CALL): {
            [[maybe_unused]] auto _ = instGuard();
            GRAPH_DEBUG("running TAIL_CALL for ", inst.X);
            frame.function->function_table_[inst.X]->ensure_defined();
            size_t remaining_bailout_depth =
                frame.function->remaining_bailout_depth_ > 0
                ? frame.function->remaining_bailout_depth_ - 1
                : 0;
            auto& f = *frame.function->function_table_[inst.X];
            size_t num_inputs = f.num_inputs();
            size_t base_pointer = frame.base_pointer;
            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
            size_t inputs_start = stack.size() - num_inputs;
            for (const auto i : c10::irange(num_inputs)) {
              stack.at(base_pointer + i) =
                  std::move(stack.at(inputs_start + i));
            }
            stack.resize(base_pointer + num_inputs);
            leaveFrame();

            callFunction(f, stack, remaining_bailout_depth, false);
            continue;
          }
          case INST(LIST_UNPACK): {
            [[maybe_unused]] auto _ = instGuard();
            listUnpack(stack, inst.X);
          }
            INST_NEXT;
          case INST(TUPLE_CONSTRUCT): {
            [[maybe_unused]] auto _ = instGuard();
            tupleConstruct(stack, inst.X);
          }
            INST_NEXT;
          case INST(TUPLE_SLICE): {
            [[maybe_unused]] auto _ = instGuard();
            tupleSlice(stack, inst.X, inst.X + inst.N);
          }
            INST_NEXT;
          case INST(NAMED_TUPLE_CONSTRUCT): {
            [[maybe_unused]] auto _ = instGuard();
            namedTupleConstruct(
                stack,
                frame.function->type_table_[inst.X]->expect<TupleType>(),
                inst.N);
          }
            INST_NEXT;
          case INST(LIST_CONSTRUCT): {
            [[maybe_unused]] auto _ = instGuard();
            const auto& type =
                frame.function->type_table_[inst.X]->expectRef<ListType>();
            listConstruct(stack, type, inst.N);
          }
            INST_NEXT;
          case INST(DICT_CONSTRUCT): {
            [[maybe_unused]] auto _ = instGuard();
            const auto& type =
                frame.function->type_table_[inst.X]->expectRef<DictType>();
            dictConstruct(stack, type, inst.N);
          }
            INST_NEXT;
          case INST(CREATE_OBJECT): {
            [[maybe_unused]] auto _ = instGuard();
            auto type =
                frame.function->type_table_[inst.X]->expect<ClassType>();
            createObject(stack, type);
          }
            INST_NEXT;
          case INST(ISINSTANCE): {
            [[maybe_unused]] auto _ = instGuard();
            at::ArrayRef<TypePtr> types(
                &frame.function->type_table_[inst.X],
                &frame.function->type_table_[inst.X] + inst.N);
            isinstance(stack, types);
          }
            INST_NEXT;
          case INST(TUPLE_INDEX): {
            [[maybe_unused]] auto _ = instGuard();
            tupleIndex(stack);
          }
            INST_NEXT;
          case INST(RAISE_EXCEPTION): {
            [[maybe_unused]] auto _ = instGuard();
            raiseExceptionWithMessage(stack);
          }
            INST_NEXT;
          case INST(UNCHECKED_CAST): {
            [[maybe_unused]] auto _ = instGuard();
            noop(stack);
          }
            INST_NEXT;
          case INST(__IS__): {
            [[maybe_unused]] auto _ = instGuard();
            is(stack);
          }
            INST_NEXT;
          case INST(UN_INITIALIZED): {
            [[maybe_unused]] auto _ = instGuard();
            unInitialized(stack);
          }
            INST_NEXT;
          case INST(__ISNOT__): {
            [[maybe_unused]] auto _ = instGuard();
            isNot(stack);
          }
            INST_NEXT;
          case INST(FORMAT): {
            [[maybe_unused]] auto _ = instGuard();
            format(stack, inst.X);
          }
            INST_NEXT;
          case INST(DEVICE): {
            [[maybe_unused]] auto _ = instGuard();
            device(stack);
          }
            INST_NEXT;
          case INST(DTYPE): {
            [[maybe_unused]] auto _ = instGuard();
            TORCH_INTERNAL_ASSERT(!stack.empty());
            dtype(stack);
          }
            INST_NEXT;
          case INST(DIM): {
            [[maybe_unused]] auto _ = instGuard();
            TORCH_INTERNAL_ASSERT(!stack.empty());
            dim(stack);
          }
            INST_NEXT;
          case INST(__NOT__): {
            [[maybe_unused]] auto _ = instGuard();
            _not(stack);
          }
            INST_NEXT;
          case INST(DICT_INDEX): {
            [[maybe_unused]] auto _ = instGuard();
            dictIndex(stack);
          }
            INST_NEXT;
          case INST(TO_LIST): {
            [[maybe_unused]] auto _ = instGuard();
            toList(stack);
          }
            INST_NEXT;
          case INST(NUM_TO_TENSOR): {
            [[maybe_unused]] auto _ = instGuard();
            numToTensorScalar(stack);
          }
            INST_NEXT;
          case INST(IS_CUDA): {
            [[maybe_unused]] auto _ = instGuard();
            isCuda(stack);
          }
            INST_NEXT;
          case INST(FORK): {
            [[maybe_unused]] auto _ = instGuard();
            // Move inputs to a separate stack
            auto& forked_fn =
                toGraphFunction(*frame.function->function_table_[inst.X]);
            InterpreterState forked_interpreter(
                forked_fn.get_executor().getPlanFor(stack).code, taskLauncher_);
            InterpreterContinuation continuation(
                forked_interpreter,
                Stack(stack.end() - inst.N, stack.end()),
                getDistAutogradContextId());
            drop(stack, inst.N);
            push(stack, forked_interpreter.getFuture());
            taskLauncher_(std::move(continuation));
          }
            INST_NEXT;
          case INST(AWAITABLE): {
            [[maybe_unused]] auto _ = instGuard();
            auto fn_ptr = frame.function->function_table_[inst.X];
            auto& fn = toGraphFunction(*fn_ptr);
            auto num_outputs = fn.graph()->outputs().size();
            TypePtr out_type;
            if (num_outputs == 1) {
              out_type = fn.graph()->outputs()[0]->type();
            } else {
              std::vector<TypePtr> out_types;
              for (const auto& o : fn.graph()->outputs()) {
                out_types.push_back(o->type());
              }
              out_type = TupleType::create(out_types);
            }
            auto args = std::vector<IValue>(stack.end() - inst.N, stack.end());
            auto aw = c10::make_intrusive<c10::ivalue::Await>(out_type);
            aw->setArgs(std::move(args));
            aw->setFn(
                [&args = aw->args(),
                 fn_ptr,
                 taskLauncher = taskLauncher_]() -> IValue {
                  auto& fn = toGraphFunction(*fn_ptr);
                  auto n_out = fn.graph()->outputs().size();
                  torch::jit::Stack s;
                  for (const auto& arg : args) {
                    s.push_back(arg);
                  }
                  InterpreterState await_interpreter(
                      fn.get_executor().getPlanFor(s).code, taskLauncher);
                  await_interpreter.run(s);
                  if (n_out == 1) {
                    return s.back();
                  }
                  return c10::ivalue::Tuple::create(jit::last(s, n_out));
                });
            drop(stack, inst.N);
            push(stack, std::move(aw));
          }
            INST_NEXT;
          case INST(WARN): {
            [[maybe_unused]] auto _ = instGuard();
            // Keeps track of which WARN instruction has been executed before;
            // we only want to execute each WARN once to match default Python
            // warning behavior.
            bool need_warn = true;
            if (inst.X != -1) {
              need_warn = warned_nodes_.insert(inst.X);
            }

            Node* node =
                frames.back().function->instructions_source_.at(frame.pc);
            auto range = node->sourceRange().source();
            if (range->filename()) {
              drop(stack, 1);
              const auto& msg = stack.back().toStringRef();
              if (need_warn) {
                auto line = range->starting_line_no() +
                    range->lineno_for_offset(node->sourceRange().start());
                c10::SourceLocation location{
                    "", range->filename()->c_str(), uint32_t(line)};
                // Sends the warning to the warning handler with the
                // "verbatim" flag. This flag ensures the warning handler
                // will print the exception as configured.
                c10::warn(c10::Warning(
                    c10::UserWarning(), location, msg, /*verbatim=*/true));
              }
              stack.pop_back();
            } else {
              if (need_warn) {
                TORCH_WARN(stack.back().toStringRef());
              }
              stack.pop_back();
            }
          }
            INST_NEXT;
        }
      }
    } catch (std::exception& e) {
      for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
           it != end;
           ++it) {
        auto& f = it->toObject()->type()->getMethod("__exit__");
        Stack stack;
        push(stack, *it);
        push(stack, IValue());
        push(stack, IValue());
        push(stack, IValue());
        try {
          f.run(stack);
        } catch (std::exception& _) {
          // TODO(T98048876): Handle `_` correctly.
        }
      }
      if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
        if (future_) {
          future_->setError(std::current_exception());
          return false;
        }
        throw;
      }
      auto* jit_exception = dynamic_cast<JITException*>(&e);
      // Janky af. See https://github.com/pytorch/pytorch/issues/54612
      auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);

      std::optional<std::string> python_class_name;
      if (jit_exception) {
        python_class_name = jit_exception->getPythonClassName();
      }
      handleError(
          e, (bool)jit_exception, not_implemented_error, python_class_name);
      return false;
    }
  }

#undef INST_NEXT
#undef INST_DISPATCH
#undef INST
#undef JIT_USE_COMPUTED_GOTO

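  // runTemplate/runImpl return true when execution suspended (a WAIT hit an
  // incomplete future; the result is later delivered through future_) and
  // false when the top frame ran to completion or the error was handled.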
  bool runImpl(Stack& stack) {
    if (!profiling::isProfilingOngoing()) {
      return runTemplate</*EnableProfiling*/ false>(stack);
    } else {
      return runTemplate</*EnableProfiling*/ true>(stack);
    }
  }

  void formatStackTrace(std::ostream& out) {
    format_stack_trace(out, callstack());
  }

  void handleError(
      const std::exception& e,
      bool is_jit_exception,
      c10::NotImplementedError* not_implemented_error,
      std::optional<std::string> python_class_name) {
    ExceptionMessage msg(e);
    std::ostringstream ss;
    std::string class_name =
        python_class_name ? *python_class_name : "RuntimeError";
    ss << "The following operation failed in the TorchScript interpreter.\n";
    formatStackTrace(ss);
    ss << class_name << ": " << msg << "\n";
    if (future_) {
      future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
    } else if (is_jit_exception) {
      // save the original exception's message when creating a new JITException
      throw JITException(ss.str(), python_class_name, e.what());
    } else if (not_implemented_error) {
      throw c10::NotImplementedError(
          ss.str(),
          not_implemented_error->backtrace(),
          not_implemented_error->caller());
    } else {
      if (get_cpp_stacktraces_enabled()) {
        ss << e.what() << "\n";
      }
      throw std::runtime_error(ss.str());
    }
  }

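  // Lazily attaches a RecordFunction to the frame when profiler/observer step
  // callbacks are registered for the TORCHSCRIPT_FUNCTION scope; its end
  // callbacks run when the frame (and the RecordFunction it owns) is
  // destroyed.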
  static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
    if (!frame.record_function) {
      auto step_callbacks = at::getStepCallbacksUnlessEmpty(
          at::RecordScope::TORCHSCRIPT_FUNCTION);
      if (C10_UNLIKELY(step_callbacks.has_value())) {
        auto rec_fn =
            std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
        if (rec_fn->needsInputs()) {
          rec_fn->before(
              frame.function->function_name_,
              last(stack, frame.function->n_inputs));
        } else {
          rec_fn->before(frame.function->function_name_);
        }
        frame.record_function = std::move(rec_fn);
      }
    }
  }

 public:
  // One way to avoid the overhead of forming strings would be to return
  // a vector of frame.function, i.e. CodeImpl*.
  // This is not exactly clean as it will expose internal details of the
  // interpreter. But this way we hold onto graph/node and Function, and
  // we can create the module hierarchy string for each event in the autograd
  // profiler at the end, when consolidating events.
  // At the moment the overhead does not seem exorbitantly large.
  // Another option would be to return a vector of (string,
  // InlinedCallstackPtrs), where the string would contain the function name
  // and the typename of self.
  // Format of the returned vector of strings:
  // For each frame, the corresponding module name, type and function name
  // are in the following format:
  // <module-instance-name>(module type)::<function-name>
  // Special keys for module-instance-name:
  //   - TOP: for the top level module
  //   - SELF: when the method/function of the frame is associated with
  //           the previous frame's module instance
  //   - INSTANCE_NAME_UNKNOWN: instance name cannot be figured out
  //   - CALL_FUNCTION: call to a free function
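  // For example, a top-level call might be reported as
  //   TOP(MyModule)::forward
  // and a method invoked on the same instance in the next frame as
  //   SELF(MyModule)::helper
  // (illustrative names only, following the format described above).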
  std::vector<std::string> moduleHierarchy() const {
    std::vector<std::string> module_function_list;
    std::string module_hierarchy("TOP");
    for (size_t i = 0; i < frames.size(); ++i) {
      const Frame& frame = frames[i];
      std::string fn_name = frame.function->function_name_;
      // For each frame, the type of the class with which the function is
      // associated is queried here, and the type name is added to the
      // module hierarchy.
      const auto& g = frame.function->graph_;
      std::string g_self_type;
      if (g && !g->inputs().empty()) {
        const auto& g_self_type_ptr =
            g->inputs()[0]->type()->cast<c10::ClassType>();
        if (g_self_type_ptr) {
          g_self_type = g_self_type_ptr->name()->qualifiedName();
          g_self_type = g_self_type.substr(g_self_type.find_last_of('.') + 1);
        }
      }
      module_hierarchy.append("(")
          .append(g_self_type)
          .append(")::")
          .append(fn_name);
      module_function_list.emplace_back(std::move(module_hierarchy));

      size_t pc = frame.pc;
      // CALL nodes have already advanced the pc, so
      // undo that to report the call node
      if (i + 1 < frames.size()) {
        --pc;
      }

      Node* node = frame.function->instructions_source_[pc];
      if (node->callstack()) {
        for (const auto& p : (*node->callstack())->vec()) {
          fn_name = std::get<0>(p)->name();
          const auto& opt_module_info = std::get<2>(p);
          if (opt_module_info.has_value()) {
            const auto& module_instance_info = opt_module_info.value();
            module_hierarchy = utils::get_module_info(module_instance_info);
            module_hierarchy.append("::").append(fn_name);
          } else {
            // This is likely a call to a free function, not associated with
            // any class
            module_hierarchy = "::";
            module_hierarchy.append(fn_name);
          }
          module_function_list.emplace_back(std::move(module_hierarchy));
        }
      }

      module_hierarchy = std::string();
      // If this node is of type CallMethod then the following frame
      // will contain the op being executed.
      // For such a CallMethod node, we add the object instance name
      // associated with it, since the following frame will not have it.
      if (node->kind() == prim::CallMethod) {
        std::string class_instance_name;
        if (node->input(0)->node()->kind() == prim::GetAttr) {
          class_instance_name = node->input(0)->node()->s(attr::name);
        } else if (
            !node->owningGraph()->inputs().empty() &&
            node->input(0) == node->owningGraph()->inputs()[0]) {
          class_instance_name = "SELF";
        } else {
          class_instance_name = "INSTANCE_NAME_UNKNOWN";
        }
        module_hierarchy = std::move(class_instance_name);
      } else if (node->kind() == prim::CallFunction) {
        auto function_constant = node->input(0)->node();
        auto fun_type =
            function_constant->output()->type()->expect<FunctionType>();
        auto fun_name = fun_type->function()->name();
        module_hierarchy = "CALL_FUNCTION::";
        module_hierarchy.append(fun_name);
      }
    }
    return module_function_list;
  }

  std::vector<StackEntry> callstack() const {
    std::vector<StackEntry> entries;
    for (const auto i : c10::irange(frames.size())) {
      const Frame& frame = frames[i];
      std::string previous_fn_name = frame.function->function_name_;
      size_t pc = frame.pc;
      // CALL nodes have already advanced the pc, so
      // undo that to report the call node
      if (i + 1 < frames.size()) {
        --pc;
      }

      Node* node = frame.function->instructions_source_[pc];
      if (node->callstack()) {
        for (const auto& p : (*node->callstack())->vec()) {
          entries.emplace_back(StackEntry{previous_fn_name, std::get<1>(p)});
          previous_fn_name = std::get<0>(p)->name();
        }
      }
      entries.emplace_back(StackEntry{previous_fn_name, node->sourceRange()});
    }
    return entries;
  }

  c10::intrusive_ptr<Future> getOrCreateFuture() {
    if (!future_) {
      future_ =
          c10::make_intrusive<Future>(frames.front().function->return_type_);
    }
    return future_;
  }

  c10::intrusive_ptr<Future> runAsync(Stack& stack) {
    getOrCreateFuture();
    runImpl(stack);
    return future_;
  }

  void run(Stack& stack) {
    // By the time the continuation completes the frame will be gone, so this
    // must be done before calling runImpl().
    TORCH_INTERNAL_ASSERT(!frames.empty());
    const auto num_outputs = frames.front().function->n_outputs;
    if (runImpl(stack)) {
      future_->wait();

      if (num_outputs == 1) {
        push(stack, future_->value());
      } else {
        auto tuple = future_->value().toTuple();
        for (const IValue& value : tuple->elements()) {
          push(stack, value);
        }
      }
    }
  }
};

std::vector<StackEntry> currentCallstack() {
  if (tls_int_state_ptr_) {
    auto cs = tls_int_state_ptr_->callstack();
    std::reverse(cs.begin(), cs.end());
    return cs;
  }
  return std::vector<StackEntry>();
}

std::vector<std::string> currentModuleHierarchy() {
  if (tls_int_state_ptr_) {
    return tls_int_state_ptr_->moduleHierarchy();
  }
  return std::vector<std::string>();
}

std::ostream& operator<<(std::ostream& out, const Code& code) {
  out << *code.pImpl->graph_ << "\n";
  code.pImpl->dump(out);
  return out;
}

Code::Code(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    size_t remaining_bailout_depth)
    : pImpl(new CodeImpl(
          graph,
          std::move(function_name),
          remaining_bailout_depth)) {}

Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}

MobileCode::MobileCode(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    bool emit_default_input_instructions,
    bool support_default_args_before_out,
    bool emit_promoted_ops,
    size_t remaining_bailout_depth)
    : Code(new interpreter::MobileCodeImpl(
          graph,
          std::move(function_name),
          emit_default_input_instructions,
          support_default_args_before_out,
          emit_promoted_ops,
          remaining_bailout_depth)) {}

const std::vector<GraphExecutor*>& Code::grad_executors() {
  return pImpl->grad_executors();
}

const std::vector<GraphExecutor*>& Code::diff_graph_op_executors() {
  return pImpl->diff_graph_op_executors();
}

size_t Code::num_bailouts() const {
  return pImpl->type_table_.size();
}

void Code::request_bailout(size_t index) {
  pImpl->request_bailout(index);
}

size_t Code::num_inputs() const {
  return pImpl->n_inputs;
}

size_t Code::num_outputs() const {
  return pImpl->n_outputs;
}

const std::vector<c10::IValue>& Code::constant_table() const {
  return pImpl->constant_table();
}

const std::vector<Instruction>& Code::instructions() const {
  return pImpl->instructions();
}

const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
    const {
  return pImpl->op_to_num_specified_args();
}

const std::vector<Node*>& Code::instructions_source() const {
  return pImpl->instructions_source();
}

const std::vector<TypePtr>& Code::type_table() const {
  return pImpl->type_table_;
}

size_t Code::register_size() const {
  return pImpl->register_size_;
}

std::shared_ptr<Graph> Code::graph() const {
  return pImpl->preprocess_.graph;
}

InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
    : pImpl(c10::make_intrusive<InterpreterStateImpl>(
          code,
          std::move(taskLauncher))) {}

void InterpreterState::run(Stack& stack) {
  static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
}

c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
  return static_cast<InterpreterStateImpl*>(pImpl.get())->runAsync(stack);
}

c10::intrusive_ptr<Future> InterpreterState::getFuture() {
  return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
}

InterpreterState::InterpreterState(
    c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
    : pImpl(std::move(pImpl_)) {}

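// Re-enters the interpreter asynchronously on the saved stack, restoring the
// distributed autograd context (when built with RPC) and any captured
// ThreadLocalState before calling runAsync().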
void InterpreterContinuation::operator()() {
#ifdef USE_RPC
  auto prev_dist_id = DistAutogradContainer::currentContextId();
  DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_);
#endif
  if (tls_state_ != std::nullopt) {
    at::ThreadLocalStateGuard g(*tls_state_);
    state.runAsync(stack);
  } else {
    state.runAsync(stack);
  }
#ifdef USE_RPC
  DistAutogradContainer::forceCurrentContextId(prev_dist_id);
#endif
}

} // namespace torch::jit