#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#ifdef USE_RPC
#include
using torch::distributed::autograd::DistAutogradContainer;
#endif

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

C10_DEFINE_bool(
    torch_jit_enable_rethrow_caught_exception,
    false,
    "enable rethrowing caught exception");

namespace torch {
namespace jit {

using CodeImpl = interpreter::CodeImpl;

// Before we translate to interpreter instructions, we do
// some preprocessing of the graph to turn it into a form that is closer
// to what the instructions will look like.
// In particular we:
// * Compute whether an input to a node is the last use, so we can issue MOVE
//   rather than LOAD instructions.
// * Insert Drop nodes for any node that is unused, to create a dummy use
//   that will cause the interpreter to free the node.
//   A drop node just pops its input off the stack to ensure the interpreter
//   releases references to nodes that are never used. Drop nodes are also
//   inserted when the last use of a node is in some conditionally run control
//   flow (e.g. one side of an If) and the interpreter must free the node only
//   after the control flow has reconverged.
// Outputs are:
// * graph - the post-processed copy of g
// * move_flags[n] - a list of booleans, one for each input,
//   indicating whether this is the last use of the value. The interpreter
//   should generate a move rather than a copy in this case.

TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
  if (!t.defined()) {
    return TensorType::get()->withUndefined();
  }
  auto r = TensorType::create(t);
  if (!at::GradMode::is_enabled()) {
    return r->withRequiresGrad(false);
  }
  return r;
}

namespace {
inline int64_t getDistAutogradContextId() {
#ifdef USE_RPC
  return DistAutogradContainer::currentContextId();
#else
  return 0;
#endif
}
} // namespace

thread_local InterpreterStateImpl* tls_int_state_ptr_ = nullptr;

struct TLSCurrentInterpreterGuard {
  TLSCurrentInterpreterGuard(InterpreterStateImpl* state) {
    prev_state_ = tls_int_state_ptr_;
    tls_int_state_ptr_ = state;
  }

  ~TLSCurrentInterpreterGuard() {
    tls_int_state_ptr_ = prev_state_;
  }

 private:
  InterpreterStateImpl* prev_state_;
};

// InterpreterStateImpl holds the state that is used to run a Code.
struct InterpreterStateImpl : c10::intrusive_ptr_target {
  InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
      : taskLauncher_(std::move(taskLauncher)) {
    enterFrame(code, 0);
  }

 private:
  using Frame = torch::jit::interpreter::Frame;

  struct WarnedNodes {
   public:
    // Inserts idx into warned_nodes_, returns a boolean indicating whether
    // the insertion actually happened (idx wasn't originally in the set).
    bool insert(int32_t idx) {
      std::unique_lock<std::mutex> lock(mutex_);
      return warned_nodes_.insert(idx).second;
    }

   private:
    std::mutex mutex_;
    std::unordered_set<int32_t> warned_nodes_;
  };

  WarnedNodes warned_nodes_;

  // If we need to suspend, where do we reset the stack?
  // Answer: to where it was when we were called, not including any inputs to
  // this function.
  int64_t stack_start_ = -1;
  c10::intrusive_ptr<Future> future_;
  TaskLauncher taskLauncher_;

  // This holds all the tensors for this interpreter run. We don't bother
  // minimizing the size of this vector, since the extra memory used by the
  // pointers in it will be small. Instead we are very aggressive about
  // releasing tensors when they become dead, to make sure memory management
  // happens efficiently. We optimize for the case where derivatives are run
  // with retain_graph=False; when it is true, the interpreter and this array
  // get copied. If this ever becomes a bottleneck then we _should_ consider
  // minimizing the total number of registers.
  std::vector<IValue> registers;

  // A stack of objects that have been __enter__'d.
  std::vector<IValue> entered_objects;

  std::vector<Frame> frames;

  c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
    c10::raw::intrusive_ptr::incref(this);
    return c10::intrusive_ptr<InterpreterStateImpl>::reclaim(this);
  }

  void enterFrame(const Code& code, size_t base_pointer) {
    frames.emplace_back(Frame{code.pImpl, 0, base_pointer, c10::nullopt});
    registers.resize(registers.size() + code.pImpl->register_size_);
  }

  void leaveFrame() {
    registers.resize(
        registers.size() - frames.back().function->register_size_);
    frames.pop_back();
  }

  void callFunction(
      Function& f,
      Stack& stack,
      c10::optional<size_t> bailOut = c10::nullopt,
      bool next = true) {
    bool newFrame = f.call(stack, bailOut, [&](const Code& code) {
      enterFrame(code, stack.size() - code.num_inputs());
      checkAndStartRecordFunction(frames.back(), stack);
    });
    if (next) {
      (frames.rbegin() + (newFrame ? 1 : 0))->pc++;
    }
  }

  // Registers are addressed relative to the end of the register list, so that
  // when we call functions we are referring to the registers of the currently
  // executing function.
  IValue& reg(size_t reg) {
    return *(registers.end() - reg);
  }

  void dump(std::ostream& out, const Stack& stack) const {
    out << "Stack:\n";
    for (const auto& val : stack) {
      out << val;
      out << "\n";
    }
  }

#if defined(__GNUC__) || defined(__clang__)
#define JIT_USE_COMPUTED_GOTO
#endif

  // Primitives for making interpreter internal state transitions.
  // We maintain two local variables as the internal interpreter state:
  // `frame` will be the current frame that the interpreter operates on.
  // `inst` will be the current instruction pointed to by the program counter.
  //
  // Instruction blocks should always be declared through the `INST` macro and
  // the instruction body should always start with an `INST_GUARD` declaration.
  // Blocks should also be ended properly with either `INST_NEXT` (for going
  // to the next instruction) or `INST_DISPATCH` (for jumping to a computed
  // position using `INST_FETCH`).
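  // For illustration (not used by the build): with JIT_USE_COMPUTED_GOTO
  // defined, a block written as
  //
  //   case INST(LOAD): {
  //     INST_GUARD;
  //     ...
  //   }
  //     INST_NEXT;
  //
  // expands roughly to
  //
  //   case LOAD:
  //   label_LOAD: {
  //     profiling::InstructionSpan span{...};
  //     ...
  //   }
  //     inst = frame.function->instructions_[frame.pc += 1];
  //     goto* dispatch_table[inst.op];
  //
  // while the portable fallback keeps only the plain `case LOAD:` label and
  // ends the block with `break`, so dispatch goes back through the enclosing
  // switch.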
#define INST_FETCH(X) (frame.function->instructions_[frame.pc += (X)]) #define INST_GUARD \ profiling::InstructionSpan span { \ *frame.function->instructions_source()[frame.pc] \ } #if defined(JIT_USE_COMPUTED_GOTO) #define INST(NAME) \ NAME: \ label_##NAME #define INST_DISPATCH goto* dispatch_table[inst.op] #else #define INST(NAME) NAME #define INST_DISPATCH break #endif #define INST_NEXT \ inst = INST_FETCH(1); \ INST_DISPATCH bool runImpl(Stack& stack) { // if we have never run before, then we might have to return the // stack when we suspend, record where it starts so we return the right // stack if (stack_start_ == -1) { TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs); stack_start_ = stack.size() - frames.back().function->n_inputs; } else { // during restarts, all of the stack is always our own, so we leave // nothing stack_start_ = 0; } TLSCurrentInterpreterGuard g(this); if (frames.back().pc == 0 && stack_start_ == 0) { checkAndStartRecordFunction(frames.back(), stack); } #if defined(JIT_USE_COMPUTED_GOTO) // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) static void* dispatch_table[] = { #define DISPATCH_TABLE_ENTRY(op, _) &&label_##op, FORALL_OPCODES(DISPATCH_TABLE_ENTRY) #undef DISPATCH_TABLE_ENTRY }; #endif try { while (true) { Frame& frame = frames.back(); Instruction inst = INST_FETCH(0); switch (inst.op) { case INST(ENTER): { INST_GUARD; const auto& obj = peek(stack, 0, 1); TORCH_INTERNAL_ASSERT(obj.isObject()); entered_objects.push_back(obj); } INST_NEXT; case INST(EXIT): { INST_GUARD; auto obj = entered_objects.back().toObject(); auto& f = obj->type()->getMethod("__exit__"); push(stack, std::move(obj)); entered_objects.pop_back(); push(stack, IValue()); push(stack, IValue()); push(stack, IValue()); callFunction(f, stack); continue; } case INST(OP): { INST_GUARD; #ifndef NDEBUG size_t init_size = stack.size(); #endif frame.function->operator_table_[inst.X](stack); #ifndef NDEBUG frame.function->assert_stack_size(inst.X, init_size, stack.size()); #endif } INST_NEXT; case INST(OPN): { INST_GUARD; stack.push_back(inst.N); #ifndef NDEBUG size_t init_size = stack.size(); #endif frame.function->operator_table_[inst.X](stack); #ifndef NDEBUG frame.function->assert_stack_size(inst.X, init_size, stack.size()); #endif } INST_NEXT; case INST(LOAD): { INST_GUARD; stack.emplace_back(reg(inst.X)); } INST_NEXT; case INST(MOVE): { INST_GUARD; stack.emplace_back(std::move(reg(inst.X))); } INST_NEXT; case INST(STORE): { INST_GUARD; reg(inst.X) = pop(stack); } INST_NEXT; case INST(STOREN): { INST_GUARD; for (size_t i = inst.N; i > 0; --i) { reg(inst.X + i - 1) = pop(stack); } } INST_NEXT; case INST(DROP): { INST_GUARD; stack.pop_back(); } INST_NEXT; case INST(DROPR): { INST_GUARD; reg(inst.X) = IValue(); } INST_NEXT; case INST(LOADC): { INST_GUARD; stack.emplace_back(frame.function->constant_table_[inst.X]); } INST_NEXT; case INST(GET_ATTR): { INST_GUARD; const auto& userObj = stack.back().toObjectRef(); stack.back() = userObj.getSlot(inst.X); } INST_NEXT; case INST(SET_ATTR): { INST_GUARD; auto v = pop(stack); auto& userObj = stack.back().toObjectRef(); userObj.setSlot(inst.X, std::move(v)); stack.pop_back(); } INST_NEXT; case INST(JF): { INST_GUARD; if (pop(stack).toBool()) { inst = INST_FETCH(1); } else { inst = INST_FETCH(inst.X); } } INST_DISPATCH; case INST(JMP): { INST_GUARD; inst = INST_FETCH(inst.X); } INST_DISPATCH; case INST(LOOP): { INST_GUARD; // stack: iteration_count, max_iter, cond, loop_carried_deps... 
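            // For illustration, assuming inst.N == 4 (i.e. two loop-carried
            // values), the top of the stack looks like
            //   ..., trip_count, max_trip_count, cond, dep0, dep1
            // and `fr` below points at trip_count. On a taken iteration the
            // cond slot is overwritten with the current trip count (the loop
            // body reads its induction variable from there) and trip_count is
            // incremented; on exit the loop-carried values are shifted down
            // over the three bookkeeping slots, which are then dropped.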
auto fr = stack.end() - (inst.N + 1); int64_t trip_count = fr[0].toInt(); int64_t max_trip_count = fr[1].toInt(); bool cond = fr[2].toBool(); if (trip_count < max_trip_count && cond) { fr[2] = trip_count; fr[0] = trip_count + 1; inst = INST_FETCH(1); } else { size_t n_loop_carried = inst.N - 2; for (const auto i : c10::irange(n_loop_carried)) { fr[i] = std::move(fr[i + 3]); } drop(stack, 3); // iteration_count, max_iter, cond inst = INST_FETCH(inst.X); } } INST_DISPATCH; case INST(CALL): { INST_GUARD; Function* fn = frame.function->function_table_[inst.X]; callFunction(*fn, stack); continue; } case INST(INTERFACE_CALL): { INST_GUARD; // note the hash table lookup to find the function // this can be more optimized if necessary, caching parts // of the hashing computation or storing the offset when // the object is turned into an interface // consider passing // `frames.back().function->remaining_bailout_depth_` into // `get_executor().getPlanFor()` to propagate caller's depth // restrictions onto children while this strategy has a potential to // reduce the number of compilations for too dynamic callers we // might miss opportunities where a caller is dynamic but a callee // gets stable arguments Function& function = peek(stack, 0, inst.N) .toObject() ->type() ->getMethod( frame.function->constant_table_[inst.X].toStringRef()); callFunction(function, stack); continue; } case INST(RET): { if (frames.size() > 1) { leaveFrame(); continue; } if (future_) { auto num_outputs = frames.back().function->n_outputs; if (num_outputs == 1) { future_->markCompleted(stack.back()); } else { future_->markCompleted( c10::ivalue::Tuple::create(jit::last(stack, num_outputs))); } } // destroy the last frame and call RecordFunction's end callbacks leaveFrame(); return false; } case INST(WAIT): { INST_GUARD; auto future = stack.back().toFuture(); if (!future->completed()) { getOrCreateFuture(); // callback needs to be a struct rather than a lambda so that // we can move the stack to the other thread struct Callback { Callback( c10::intrusive_ptr state, Stack stack) : stateImpl_(std::move(state)), state_(stateImpl_), stack_(std::move(stack)) { dist_autograd_context_id_ = getDistAutogradContextId(); state_ = InterpreterState(stateImpl_); } void operator()(c10::ivalue::Future& /* unused */) { stateImpl_->taskLauncher_(InterpreterContinuation( state_, std::move(stack_), dist_autograd_context_id_, std::move(tls_state_))); } private: c10::intrusive_ptr stateImpl_; InterpreterState state_; Stack stack_; int64_t dist_autograd_context_id_; // preserve the original ThreadLocalState at::ThreadLocalState tls_state_; }; // we are suspending, so we need to reset the stack to where we // started if it started empty, except for the inputs we can avoid // a true copy by swapping, which leaves the original stack empty. 
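            // For illustration: if this frame owns the whole stack
            // (stack_start_ == 0), the swap below hands the entire stack to
            // the callback and leaves `stack` empty for the caller; otherwise
            // only our suffix is moved out and the caller keeps its prefix:
            //   before: [caller values | our values]   (stack_start_ == k)
            //   after:  [caller values]                copied == [our values]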
Stack copied; if (stack_start_ == 0) { copied.swap(stack); } else { copied.insert( copied.begin(), std::make_move_iterator(stack.begin() + stack_start_), std::make_move_iterator(stack.end())); stack.resize(stack_start_); } // save pc into the frame so we continue here when restored future->addCallback( Callback(intrusive_from_this(), std::move(copied))); return true; } stack.pop_back(); stack.emplace_back(future->value()); } INST_NEXT; case INST(PROFILE_OP): { INST_GUARD; auto& frame_id_ref = frame.id; if (!frame_id_ref.has_value()) { frame_id_ref = Frame::genId(); } const auto& callback = frame.function->profile_function_table_[inst.X]; push(stack, c10::IValue{static_cast(*frame_id_ref)}); callback(stack); } INST_NEXT; case INST(FAIL_GUARD): { INST_GUARD; // patch FAIL_GUARD back to GUARD GRAPH_DEBUG( "Bailout ", inst.X, " triggered via bailout_requests_!"); frame.function->instructions_[frame.pc].op = GUARD; push(stack, false); } INST_NEXT; case INST(TYPECHECK): { INST_GUARD; int num_inputs = inst.N, i = 0; // NOLINTNEXTLINE(clang-diagnostic-sign-compare) TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0); // Check every input's shape against profiled (expected) shape. for (i = 0; i < num_inputs; i++) { auto& input = peek(stack, i, num_inputs); auto& t = input.toTensor(); const TypePtr& expected = frame.function->type_table_[inst.X + i]; auto* expected_type = expected->castRaw(); if (t.defined() && !expected_type->matchTensor(t)) { push(stack, false); break; } } if (i == num_inputs) { push(stack, true); } } INST_NEXT; case INST(GUARD): { INST_GUARD; if (!stack.back().isTensor()) { // stack.back() is an Uninitialized IValue and this is a guard // on a block output. Uninitialized IValues are never used // so it's safe to pass this guard check push(stack, true); } else { auto& t = stack.back().toTensor(); const TypePtr& expected = frame.function->type_table_[inst.X]; auto* expected_type = expected->castRaw(); if (t.defined() && !frames.back().symbols2dims.bindSymbolicShapes( t.sizes(), expected_type->symbolic_sizes())) { push(stack, false); } else { push(stack, expected_type->matchTensor(t)); } } } INST_NEXT; case INST(TAIL_CALL): { INST_GUARD; GRAPH_DEBUG("running TAIL_CALL for ", inst.X); frame.function->function_table_[inst.X]->ensure_defined(); size_t remaining_bailout_depth = frame.function->remaining_bailout_depth_ > 0 ? 
frame.function->remaining_bailout_depth_ - 1 : 0; auto& f = *frame.function->function_table_[inst.X]; size_t num_inputs = f.num_inputs(); size_t base_pointer = frame.base_pointer; TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs); size_t inputs_start = stack.size() - num_inputs; for (const auto i : c10::irange(num_inputs)) { stack.at(base_pointer + i) = std::move(stack.at(inputs_start + i)); } stack.resize(base_pointer + num_inputs); leaveFrame(); callFunction(f, stack, remaining_bailout_depth, false); continue; } case INST(LIST_UNPACK): { INST_GUARD; listUnpack(stack, inst.X); } INST_NEXT; case INST(TUPLE_CONSTRUCT): { INST_GUARD; tupleConstruct(stack, inst.X); } INST_NEXT; case INST(TUPLE_SLICE): { INST_GUARD; tupleSlice(stack, inst.X, inst.X + inst.N); } INST_NEXT; case INST(NAMED_TUPLE_CONSTRUCT): { INST_GUARD; namedTupleConstruct( stack, frame.function->type_table_[inst.X]->expect(), inst.N); } INST_NEXT; case INST(LIST_CONSTRUCT): { INST_GUARD; const auto& type = frame.function->type_table_[inst.X]->expectRef(); listConstruct(stack, type, inst.N); } INST_NEXT; case INST(DICT_CONSTRUCT): { INST_GUARD; const auto& type = frame.function->type_table_[inst.X]->expectRef(); dictConstruct(stack, type, inst.N); } INST_NEXT; case INST(CREATE_OBJECT): { INST_GUARD; auto type = frame.function->type_table_[inst.X]->expect(); createObject(stack, type); } INST_NEXT; case INST(ISINSTANCE): { INST_GUARD; at::ArrayRef types( &frame.function->type_table_[inst.X], &frame.function->type_table_[inst.X] + inst.N); isinstance(stack, types); } INST_NEXT; case INST(TUPLE_INDEX): { INST_GUARD; tupleIndex(stack); } INST_NEXT; case INST(RAISE_EXCEPTION): { INST_GUARD; raiseExceptionWithMessage(stack); } INST_NEXT; case INST(UNCHECKED_CAST): { INST_GUARD; noop(stack); } INST_NEXT; case INST(__IS__): { INST_GUARD; is(stack); } INST_NEXT; case INST(UN_INITIALIZED): { INST_GUARD; unInitialized(stack); } INST_NEXT; case INST(__ISNOT__): { INST_GUARD; isNot(stack); } INST_NEXT; case INST(FORMAT): { INST_GUARD; format(stack, inst.X); } INST_NEXT; case INST(DEVICE): { INST_GUARD; device(stack); } INST_NEXT; case INST(DTYPE): { INST_GUARD; dtype(stack); } INST_NEXT; case INST(DIM): { INST_GUARD; dim(stack); } INST_NEXT; case INST(__NOT__): { INST_GUARD; _not(stack); } INST_NEXT; case INST(DICT_INDEX): { INST_GUARD; dictIndex(stack); } INST_NEXT; case INST(TO_LIST): { INST_GUARD; toList(stack); } INST_NEXT; case INST(NUM_TO_TENSOR): { INST_GUARD; numToTensorScalar(stack); } INST_NEXT; case INST(IS_CUDA): { INST_GUARD; isCuda(stack); } INST_NEXT; case INST(FORK): { INST_GUARD; // Move inputs to a separate stack auto& forked_fn = toGraphFunction(*frame.function->function_table_[inst.X]); InterpreterState forked_interpreter( forked_fn.get_executor().getPlanFor(stack).code, taskLauncher_); InterpreterContinuation continuation( forked_interpreter, Stack(stack.end() - inst.N, stack.end()), getDistAutogradContextId()); drop(stack, inst.N); push(stack, forked_interpreter.getFuture()); taskLauncher_(std::move(continuation)); } INST_NEXT; case INST(WARN): { INST_GUARD; // Keeps track of which WARN instruction has been executed before, // we only want to execute each WARN once to match default Python // warning behavior. 
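          // For illustration: inst.X identifies the warning site. The first
          // time a given site executes it is inserted into warned_nodes_ and
          // the warning fires; later executions of the same site are silent.
          // A site emitted with inst.X == -1 has no dedup key and warns on
          // every execution.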
          bool need_warn = true;
          if (inst.X != -1) {
            need_warn = warned_nodes_.insert(inst.X);
          }
          Node* node =
              frames.back().function->instructions_source_.at(frame.pc);
          auto range = node->sourceRange().source();
          if (range->filename()) {
            drop(stack, 1);
            const auto& msg = stack.back().toStringRef();
            if (need_warn) {
              auto line = range->starting_line_no() +
                  range->lineno_for_offset(node->sourceRange().start());
              c10::SourceLocation location{
                  "", range->filename()->c_str(), uint32_t(line)};
              // Sends the warning to the warning handler with the
              // "verbatim" flag. This flag ensures the warning handler
              // will print the exception as configured.
              c10::Warning::warn(location, msg, /*verbatim=*/true);
            }
            stack.pop_back();
          } else {
            const auto& msg = stack.back().toStringRef();
            if (need_warn) {
              TORCH_WARN(msg);
            }
            stack.pop_back();
          }
        }
          INST_NEXT;
      }
    }
  } catch (std::exception& e) {
    for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
         it != end;
         ++it) {
      auto& f = it->toObject()->type()->getMethod("__exit__");
      Stack stack;
      push(stack, *it);
      push(stack, IValue());
      push(stack, IValue());
      push(stack, IValue());
      try {
        f.run(stack);
      } catch (std::exception& _) {
        // TODO(T98048876): Handle `_` correctly.
      }
    }
    if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
      if (future_) {
        future_->setError(std::current_exception());
        return false;
      }
      throw;
    }
    auto* jit_exception = dynamic_cast<JITException*>(&e);
    // Janky af. See https://github.com/pytorch/pytorch/issues/54612
    auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
    c10::optional<std::string> python_class_name;
    if (jit_exception) {
      python_class_name = jit_exception->getPythonClassName();
    }
    handleError(
        e, (bool)jit_exception, not_implemented_error, python_class_name);
    return false;
  }
}

#undef INST_NEXT
#undef INST_DISPATCH
#undef INST
#undef INST_GUARD
#undef INST_FETCH
#undef JIT_USE_COMPUTED_GOTO

  void formatStackTrace(std::ostream& out) {
    format_stack_trace(out, callstack());
  }

  void handleError(
      const std::exception& e,
      bool is_jit_exception,
      c10::NotImplementedError* not_implemented_error,
      c10::optional<std::string> python_class_name) {
    ExceptionMessage msg(e);
    std::ostringstream ss;
    std::string class_name =
        python_class_name ? *python_class_name : "RuntimeError";
    ss << "The following operation failed in the TorchScript interpreter.\n";
    formatStackTrace(ss);
    ss << class_name << ": " << msg << "\n";
    if (future_) {
      future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
    } else if (is_jit_exception) {
      // save the original exception's message when creating a new JITException
      throw JITException(ss.str(), python_class_name, e.what());
    } else if (not_implemented_error) {
      throw c10::NotImplementedError(
          ss.str(),
          not_implemented_error->backtrace(),
          not_implemented_error->caller());
    } else {
      if (get_cpp_stacktraces_enabled()) {
        ss << e.what() << "\n";
      }
      throw std::runtime_error(ss.str());
    }
  }

  static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
    if (!frame.record_function) {
      auto step_callbacks = at::getStepCallbacksUnlessEmpty(
          at::RecordScope::TORCHSCRIPT_FUNCTION);
      if (C10_UNLIKELY(step_callbacks.has_value())) {
        auto rec_fn =
            std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
        if (rec_fn->needsInputs()) {
          rec_fn->before(
              frame.function->function_name_,
              last(stack, frame.function->n_inputs));
        } else {
          rec_fn->before(frame.function->function_name_);
        }
        frame.record_function = std::move(rec_fn);
      }
    }
  }

 public:
  // One way to avoid the overhead of forming strings would be to return
  // a vector of frame.function, i.e.
CodeImpl* // This is not exactly clean as it will expose, internal details of // interpreter. But this way we hold onto graph/node and Function and // we can create module hierarchy string for each event in autograd // profiler at the end, when consolidating events. // At the moment overhead does not seem exhorbitantly large. // Another option would be return vector of (string, InlinedCallstackPtrs) // string would contain function name and typename of self // Format of the returned vector of strings: // For each frame, the corresponding module name, type and function name // are in following format: // (module type):: // Special keys for module-instance-name: // - TOP: for top level module // - SELF: When method/function of the frame is associated with // previous frame's module instance // - INSTANCE_NAME_UNKNOWN: instance name cannot be figured out // - CALL_FUNCTION: call to free function std::vector moduleHierarchy() const { std::vector module_function_list; std::string module_hierarchy("TOP"); for (size_t i = 0; i < frames.size(); ++i) { const Frame& frame = frames[i]; std::string fn_name = frame.function->function_name_; // For each frame, type of the class with which the function is // associated, is queried here. And the type name is added to // module hierarchy. const auto& g = frame.function->graph_; std::string g_self_type; if (g && g->inputs().size() > 0) { const auto& g_self_type_ptr = g->inputs()[0]->type()->cast(); if (g_self_type_ptr) { g_self_type = g_self_type_ptr->name()->qualifiedName(); g_self_type = g_self_type.substr(g_self_type.find_last_of('.') + 1); } } module_hierarchy.append("(") .append(g_self_type) .append(")::") .append(fn_name); module_function_list.emplace_back(std::move(module_hierarchy)); size_t pc = frame.pc; // CALL nodes have already advanced the pc, so // undo that to report the call node if (i + 1 < frames.size()) { --pc; } Node* node = frame.function->instructions_source_[pc]; if (node->callstack()) { for (const auto& p : (*node->callstack())->vec()) { fn_name = std::get<0>(p)->name(); const auto& opt_module_info = std::get<2>(p); if (opt_module_info.has_value()) { const auto& module_instance_info = opt_module_info.value(); module_hierarchy = utils::get_module_info(module_instance_info); module_hierarchy.append("::").append(fn_name); } else { // This is likely a call to free function, not associated with // any class module_hierarchy = "::"; module_hierarchy.append(fn_name); } module_function_list.emplace_back(std::move(module_hierarchy)); } } module_hierarchy = std::string(); // If this node is of type callMethod then the following frame // will contain the op being executed. // For such callMethod node, we add the object instance name // associated with it, since the following frame will not have it. 
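      // Example (hypothetical attribute names): for a call lowered from
      // `self.encoder.forward(x)`, input(0) of the CallMethod node comes from
      // a prim::GetAttr node whose attr::name is "encoder", so "encoder" is
      // reported as the instance name for the following frame; a method
      // called directly on `self` is reported as "SELF".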
      if (node->kind() == prim::CallMethod) {
        std::string class_instance_name;
        if (node->input(0)->node()->kind() == prim::GetAttr) {
          class_instance_name = node->input(0)->node()->s(attr::name);
        } else if (
            node->owningGraph()->inputs().size() > 0 &&
            node->input(0) == node->owningGraph()->inputs()[0]) {
          class_instance_name = "SELF";
        } else {
          class_instance_name = "INSTANCE_NAME_UNKNOWN";
        }
        module_hierarchy = std::move(class_instance_name);
      } else if (node->kind() == prim::CallFunction) {
        auto function_constant = node->input(0)->node();
        auto fun_type =
            function_constant->output()->type()->expect<FunctionType>();
        auto fun_name = fun_type->function()->name();
        module_hierarchy = "CALL_FUNCTION::";
        module_hierarchy.append(fun_name);
      }
    }
    return module_function_list;
  }

  std::vector<StackEntry> callstack() const {
    std::vector<StackEntry> entries;
    for (const auto i : c10::irange(frames.size())) {
      const Frame& frame = frames[i];
      std::string previous_fn_name = frame.function->function_name_;
      size_t pc = frame.pc;
      // CALL nodes have already advanced the pc, so
      // undo that to report the call node
      if (i + 1 < frames.size()) {
        --pc;
      }

      Node* node = frame.function->instructions_source_[pc];
      if (node->callstack()) {
        for (const auto& p : (*node->callstack())->vec()) {
          entries.emplace_back(StackEntry{previous_fn_name, std::get<1>(p)});
          previous_fn_name = std::get<0>(p)->name();
        }
      }
      entries.emplace_back(StackEntry{previous_fn_name, node->sourceRange()});
    }
    return entries;
  }

  c10::intrusive_ptr<Future> getOrCreateFuture() {
    if (!future_) {
      future_ =
          c10::make_intrusive<Future>(frames.front().function->return_type_);
    }
    return future_;
  }

  c10::intrusive_ptr<Future> runAsync(Stack& stack) {
    getOrCreateFuture();
    runImpl(stack);
    return future_;
  }

  void run(Stack& stack) {
    // By the time the continuation completes, the frame will be gone, so this
    // must be done before calling runImpl().
    TORCH_INTERNAL_ASSERT(!frames.empty());
    const auto num_outputs = frames.front().function->n_outputs;
    if (runImpl(stack)) {
      future_->wait();

      if (num_outputs == 1) {
        push(stack, future_->value());
      } else {
        auto tuple = future_->value().toTuple();
        for (const IValue& value : tuple->elements()) {
          push(stack, value);
        }
      }
    }
  }
};

std::vector<StackEntry> currentCallstack() {
  if (tls_int_state_ptr_) {
    auto cs = tls_int_state_ptr_->callstack();
    std::reverse(cs.begin(), cs.end());
    return cs;
  }
  return std::vector<StackEntry>();
}

std::vector<std::string> currentModuleHierarchy() {
  if (tls_int_state_ptr_) {
    return tls_int_state_ptr_->moduleHierarchy();
  }
  return std::vector<std::string>();
}

std::ostream& operator<<(std::ostream& out, const Code& code) {
  out << *code.pImpl->graph_ << "\n";
  code.pImpl->dump(out);
  return out;
}

Code::Code(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    size_t remaining_bailout_depth)
    : pImpl(new CodeImpl(
          graph,
          std::move(function_name),
          remaining_bailout_depth)) {}

Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}

Code::~Code() = default;

MobileCode::MobileCode(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    bool emit_default_input_instructions,
    bool support_default_args_before_out,
    bool emit_promoted_ops,
    size_t remaining_bailout_depth)
    : Code(new interpreter::MobileCodeImpl(
          graph,
          std::move(function_name),
          emit_default_input_instructions,
          support_default_args_before_out,
          emit_promoted_ops,
          remaining_bailout_depth)) {}

MobileCode::~MobileCode() = default;

const std::vector<GraphExecutor*>& Code::grad_executors() {
  return pImpl->grad_executors();
}

const std::vector<GraphExecutor*>& Code::diff_graph_op_executors() {
  return pImpl->diff_graph_op_executors();
}

size_t Code::num_bailouts() const {
  return pImpl->type_table_.size();
}

void Code::request_bailout(size_t index) {
  pImpl->request_bailout(index);
}

size_t Code::num_inputs() const {
  return pImpl->n_inputs;
}

size_t Code::num_outputs() const {
  return pImpl->n_outputs;
}

const std::vector<c10::IValue>& Code::constant_table() const {
  return pImpl->constant_table();
}

const std::vector<Instruction>& Code::instructions() const {
  return pImpl->instructions();
}

const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
    const {
  return pImpl->op_to_num_specified_args();
}

const std::vector<Node*>& Code::instructions_source() const {
  return pImpl->instructions_source();
}

const std::vector<TypePtr>& Code::type_table() const {
  return pImpl->type_table_;
}

size_t Code::register_size() const {
  return pImpl->register_size_;
}

InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
    : pImpl(c10::make_intrusive<InterpreterStateImpl>(
          code,
          std::move(taskLauncher))) {}

InterpreterState::~InterpreterState() = default;

void InterpreterState::run(Stack& stack) {
  static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
}

c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
  return static_cast<InterpreterStateImpl*>(pImpl.get())->runAsync(stack);
}

c10::intrusive_ptr<Future> InterpreterState::getFuture() {
  return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
}

InterpreterState::InterpreterState(
    c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
    : pImpl(std::move(pImpl_)) {}

void InterpreterContinuation::operator()() {
#ifdef USE_RPC
  auto prev_dist_id = DistAutogradContainer::currentContextId();
  DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_);
#endif
  if (tls_state_ != c10::nullopt) {
    at::ThreadLocalStateGuard g(*tls_state_);
    state.runAsync(stack);
  } else {
    state.runAsync(stack);
  }
#ifdef USE_RPC
  DistAutogradContainer::forceCurrentContextId(prev_dist_id);
#endif
}

} // namespace jit
} // namespace torch
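// Usage sketch (illustrative only; the graph is assumed to come from
// elsewhere, e.g. a scripted or traced function). It exercises the public API
// defined above: compile a Graph into Code once, then run it through an
// InterpreterState, either synchronously or asynchronously.
//
//   #include <torch/csrc/jit/runtime/interpreter.h>
//
//   void runGraph(
//       const std::shared_ptr<torch::jit::Graph>& graph,
//       torch::jit::Stack& stack) {
//     using namespace torch::jit;
//     Code code(
//         graph, /*function_name=*/"example", /*remaining_bailout_depth=*/0);
//     InterpreterState state(code, at::launch);
//     state.run(stack); // synchronous: outputs are left on `stack`
//
//     // Asynchronous variant (on a fresh state):
//     // auto future = InterpreterState(code, at::launch).runAsync(stack);
//     // future->wait();
//   }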