#include <torch/csrc/jit/mobile/interpreter.h>

#include <sstream>

#include <ATen/core/class_type.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/function.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/operator_name.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace torch::jit {
std::ostream& operator<<(std::ostream& out, Instruction inst);
namespace mobile {
InterpreterState::InterpreterState(const Code& code) {
  enterFrame(code);
}

namespace {
static thread_local std::vector<DebugHandle> exception_debug_handles_;

void createObject(Stack& stack, const at::ClassTypePtr& type) {
  auto userObj = c10::ivalue::Object::create(
      c10::StrongTypePtr(type->compilation_unit(), type),
      type->numAttributes());
  push(stack, std::move(userObj));
}

void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
  at::TypePtr ty = pop(stack).type();
  for (const at::TypePtr& candidate : types) {
    if (ty->isSubtypeOf(*candidate)) {
      push(stack, true);
      return;
    }
  }
  push(stack, false);
}
} // namespace

using namespace at;

const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
  return exception_debug_handles_;
}

void InterpreterState::enterFrame(const Code& code) {
  frames_.emplace_back(code);
  registers_.resize(registers_.size() + code.register_size_);
}

void InterpreterState::leaveFrame() {
  registers_.resize(
      registers_.size() - frames_.back().getCode().register_size_);
  frames_.pop_back();
}

void InterpreterState::saveExceptionDebugHandles() {
  std::vector<DebugHandle> exception_debug_handles;
  for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
    // Every frame above the faulting one has already advanced its PC past
    // the call instruction, so step back by one to report the call site.
    size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
    if (auto handle = frame->getDebugHandle(pc)) {
      exception_debug_handles.push_back(*handle);
    } else {
      exception_debug_handles.push_back(-1);
    }
  }
  exception_debug_handles_ = std::move(exception_debug_handles);
}

void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
  bool newFrame =
      f.call(stack, [&](const mobile::Code& code) { enterFrame(code); });
  // Advance the PC of the calling frame past the call instruction; if no
  // new frame was pushed, the current frame is the caller.
  (frames_.rbegin() + (newFrame ? 1 : 0))->step();
}

bool InterpreterState::run(Stack& stack) {
  while (true) {
    try {
      auto& frame = frames_.back();
      const auto& code = frame.getCode();
      const auto pc = frame.getPC();
      auto inst = frame.getInstruction();
      // If no valid debug handle is found, just log the pc.
      // This is possible when we did not save debug handles.
      DebugHandle debug_handle = pc;
      if (auto handle = frame.getDebugHandle()) {
        debug_handle = *handle;
      }

      // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
      // if (inst.op == OP) {
      //   std::cout << ", " << code.op_names_[inst.X].name;
      //   if (!code.op_names_[inst.X].overload_name.empty()) {
      //     std::cout << "." << code.op_names_[inst.X].overload_name;
      //   }
      // }
      // std::cout << std::endl;

      // TODO(iliacher): remove the workaround after RecordFunction is in
      // Dispatcher. Check with iliacher whether this has been done.
      // Also note this is not exception-safe: if an op throws, the
      // RecordFunction is left enabled. That is a TODO.
      // NOTE: this recordFunction logic takes up ~2-3% of cpu cycles in some
      // workflows. Do we need it, and/or can we opt out of
      // isRecordFunctionEnabled with a macro? If we delete it, things appear
      // to work just fine.
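      // Dispatch below: each opcode is responsible for advancing the program
      // counter itself. Most cases call frame.step(); control flow uses
      // frame.jump(); CALL/INTERFACE_CALL advance the calling frame inside
      // callFunction(); RET pops the frame instead of stepping.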
      bool prev_value = isRecordFunctionEnabled();
      if (!prev_value) {
        // enable only for the RecordFunction
        enableRecordFunction(true);
      }
      switch (inst.op) {
        case OP: {
          if (at::hasGlobalCallbacks()) {
            if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>(
                    c10::ThreadLocalDebugInfo::get(
                        c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) {
              mobile_debug_info->setOpIdx(pc);
            }
          }
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.op_names_.size() ||
              static_cast<size_t>(inst.X) >= code.operators_.size()) {
            TORCH_CHECK(false, "Can't load op with index: ", inst.X);
          }
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_[inst.X].name, debug_handle, stack);
          code.operators_[inst.X](stack);
          frame.step();
        } break;
        case OPN: {
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.op_names_.size() ||
              static_cast<size_t>(inst.X) >= code.operators_.size()) {
            TORCH_CHECK(false, "Can't load op with index: ", inst.X);
          }
          stack.emplace_back(inst.N);
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_[inst.X].name, debug_handle, stack);
          code.operators_[inst.X](stack);
          frame.step();
        } break;
        case CALL: {
          auto& function = *frame.getCode().functions_.at(inst.X);
          callFunction(function, stack);
        } break;
        case INTERFACE_CALL: {
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.constants_.size()) {
            TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
          }
          if (inst.N == 0 || inst.N > stack.size()) {
            TORCH_CHECK(
                false,
                "INTERFACE_CALL N=",
                inst.N,
                " not in range [1, ",
                stack.size(),
                "]");
          }
          torch::jit::Function& method =
              peek(stack, 0, inst.N)
                  .toObject()
                  ->type()
                  ->getMethod(code.constants_[inst.X].toStringRef());
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              method.name(), debug_handle, stack);
          callFunction(method, stack);
        } break;
        case LOAD:
          stack.emplace_back(reg(inst.X));
          frame.step();
          break;
        case MOVE:
          stack.emplace_back(std::move(reg(inst.X)));
          frame.step();
          break;
        case STORE:
          reg(inst.X) = pop(stack);
          frame.step();
          break;
        case STOREN:
          // Pop N values into consecutive registers; the topmost value of
          // the stack lands in the highest register.
          for (size_t i = inst.N; i > 0; --i) {
            reg(inst.X + i - 1) = pop(stack);
          }
          frame.step();
          break;
        case DROP:
          pop(stack);
          frame.step();
          break;
        case DROPR:
          reg(inst.X) = IValue();
          frame.step();
          break;
        case LOADC:
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.constants_.size()) {
            TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
          }
          stack.emplace_back(code.constants_[inst.X]);
          frame.step();
          break;
        case GET_ATTR: {
          auto userObj = pop(stack).toObject();
          auto value = userObj->getSlot(inst.X);
          push(stack, std::move(value));
          frame.step();
        } break;
        case SET_ATTR: {
          auto v = pop(stack);
          auto userObj = pop(stack).toObject();
          // Mobile only: since the number of slots is not known, resize the
          // numAttributes before setSlot.
          while (static_cast<int64_t>(userObj->type()->numAttributes()) <=
                 inst.X) {
            std::stringstream ss;
            ss << userObj->type()->numAttributes();
            userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
          }
          userObj->setSlot(inst.X, std::move(v));
          frame.step();
        } break;
        case JF:
          frame.jump(pop(stack).toBool() ? 1 : inst.X);
          break;
        case JMP:
          frame.jump(inst.X);
          break;
        case LOOP: {
          // stack: iteration_count, max_iter, cond, loop_carried_deps...
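          // Worked example with one loop-carried value (inst.N == 3):
          //   continuing:  [..., trip, max, cond, dep]
          //            ->  [..., trip+1, max, trip, dep]  then step into body
          //   exiting:     [..., trip, max, cond, dep]
          //            ->  [..., dep]                     then jump(inst.X)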
          auto sframe = stack.end() - (inst.N + 1);
          int64_t trip_count = sframe[0].toInt();
          int64_t max_trip_count = sframe[1].toInt();
          bool cond = sframe[2].toBool();
          if (trip_count < max_trip_count && cond) {
            sframe[2] = trip_count;
            sframe[0] = trip_count + 1;
            frame.step();
          } else {
            size_t n_loop_carried = inst.N - 2;
            for (const auto i : c10::irange(n_loop_carried)) {
              sframe[i] = std::move(sframe[i + 3]);
            }
            drop(stack, 3); // iteration_count, max_iter, cond
            frame.jump(inst.X);
          }
        } break;
        case RET:
          leaveFrame();
          if (!frames_.empty()) {
            continue;
          }
          return false;
        case LIST_CONSTRUCT: {
          listConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case LIST_UNPACK: {
          listUnpack(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_CONSTRUCT: {
          tupleConstruct(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_SLICE: {
          tupleSlice(stack, inst.X, inst.X + inst.N);
          frame.step();
        } break;
        case TUPLE_INDEX: {
          tupleIndex(stack);
          frame.step();
        } break;
        case RAISE_EXCEPTION: {
          raiseExceptionWithMessage(stack);
          frame.step();
        } break;
        case __IS__: {
          is(stack);
          frame.step();
        } break;
        case UN_INITIALIZED: {
          unInitialized(stack);
          frame.step();
        } break;
        case __ISNOT__: {
          isNot(stack);
          frame.step();
        } break;
        case FORMAT: {
          format(stack, inst.X);
          frame.step();
        } break;
        case DEVICE: {
          device(stack);
          frame.step();
        } break;
        case DTYPE: {
          dtype(stack);
          frame.step();
        } break;
        case DIM: {
          dim(stack);
          frame.step();
        } break;
        case __NOT__: {
          _not(stack);
          frame.step();
        } break;
        case DICT_INDEX: {
          dictIndex(stack);
          frame.step();
        } break;
        case TO_LIST: {
          toList(stack);
          frame.step();
        } break;
        case NUM_TO_TENSOR: {
          numToTensorScalar(stack);
          frame.step();
        } break;
        case IS_CUDA: {
          isCuda(stack);
          frame.step();
        } break;
        case DICT_CONSTRUCT: {
          dictConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case NAMED_TUPLE_CONSTRUCT: {
          namedTupleConstruct(stack, code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case CREATE_OBJECT: {
          auto type = code.types_.at(inst.X)->expect<c10::ClassType>();
          createObject(stack, type);
          frame.step();
        } break;
        case ISINSTANCE: {
          at::ArrayRef<at::TypePtr> types(&code.types_.at(inst.X), inst.N);
          isinstance(stack, types);
          frame.step();
        } break;
        case WARN: {
          drop(stack, 1);
          // Note: Please don't move the pop(stack) code below into the
          // TORCH_WARN macro, since TORCH_WARN fails to evaluate its
          // arguments when STRIP_ERROR_MESSAGES is defined (which happens
          // for production mobile builds). This will cause the stack to be
          // in an inconsistent state. It has previously resulted in a SEV
          // (S22350).
          TORCH_WARN(stack.back().toStringRef());
          stack.pop_back();
          frame.step();
        } break;
        default:
          TORCH_CHECK(false, toString(inst.op), " is invalid.");
      }

      if (!prev_value) {
        enableRecordFunction(false);
      }
      // This exception must be caught first as it derives from c10::Error
    } catch (c10::BackendRuntimeException& e) {
      saveExceptionDebugHandles();
      TORCH_RETHROW(e);
    } catch (c10::Error& error) {
      // Catch and rethrow the error so that we can save the exception debug
      // handles that are queried later.
      saveExceptionDebugHandles();
      TORCH_RETHROW(error);
    } catch (...) {
      saveExceptionDebugHandles();
      throw;
    }

    // for (auto val : stack) {
    //   if (val.isTensor()) {
    //     std::cout << val.toTensor().sizes() << std::endl;
    //   } else {
    //     std::cout << val << std::endl;
    //   }
    // }
  }
  return false;
}

IValue& InterpreterState::reg(size_t reg) {
  TORCH_CHECK(
      reg > 0 && reg <= registers_.size(), "Invalid register index: ", reg);
  // Registers are indexed 1-based from the top of the register stack.
  return *(registers_.end() - reg);
}

} // namespace mobile
} // namespace torch::jit
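// Usage sketch (illustrative only, not part of this translation unit;
// assumes mobile::Function::get_code() from torch/csrc/jit/mobile/function.h):
// a caller seeds a Stack with the method's inputs, builds an
// InterpreterState from the function's Code, and runs it to completion.
// Outputs are left on the stack.
//
//   Stack stack = {self, input};
//   InterpreterState state(function.get_code());
//   state.run(stack); // returns once the outermost frame executes RET
//   IValue result = stack.back();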