#include #include #include #include #include #include #include #include #include #include namespace torch::jit { std::ostream& operator<<(std::ostream& out, Instruction inst); namespace mobile { void CompilationUnit::register_function(std::unique_ptr fn) { methods_.emplace_back(std::move(fn)); } const Function* CompilationUnit::find_function( const c10::QualifiedName& qn) const { for (auto& fn : methods_) { if (fn->qualname() == qn) { return fn.get(); } } return nullptr; } Function* CompilationUnit::find_function(const c10::QualifiedName& qn) { // NOLINTNEXTLINE return const_cast( static_cast(this)->find_function(qn)); } Method Module::get_method(const std::string& name) const { if (auto method = find_method(name)) { return *method; } AT_ERROR("Method '", name, "' is not defined."); } bool Module::compareMethodSchemas( const std::string& name_1, const std::string& name_2) { std::optional schema_1, schema_2; for (const auto& fn : cu_->methods()) { if (fn->name() == name_1) { schema_1 = fn->getSchema(); } if (fn->name() == name_2) { schema_2 = fn->getSchema(); } } if (schema_1.has_value() && schema_2.has_value()) { return (schema_1 == schema_2); } return false; } void Module::unsafeRemoveMethod(const std::string& basename) { int64_t i = 0; for (; i < static_cast(cu_->methods().size()); ++i) { if ((cu_->methods()[i])->name() == basename) { break; } } object_->type()->unsafeRemoveMethod(basename); cu_->unsafeRemoveFunction(i); } void Module::unsafeCopyMethod( const std::string& new_method_name, const Function& to_be_copied) { TORCH_CHECK( !find_method(new_method_name).has_value(), "Trying to replace existing method."); const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname(); c10::QualifiedName qualified_method_name( tobe_copied_name.prefix(), new_method_name); std::unique_ptr new_fn = std::make_unique( qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema()); object_->type()->addMethod(new_fn.get()); cu_->register_function(std::move(new_fn)); } std::optional Module::find_method(const std::string& basename) const { for (const auto& fn : cu_->methods()) { if (fn->name() == basename) { return std::make_optional(Method(this, fn.get())); } } return std::nullopt; } namespace { // For JIT, there is a private function to get all modules by iteration in // struct slot_iterator_impl (jit/api/module.h). The following function use // recursion to mimic the logic without allocating extra memory to get module // list and set training attribute directly. void set_train_recurse( const c10::intrusive_ptr& obj, bool on) { if (auto slot = obj->type()->findAttributeSlot("training")) { obj->setSlot(*slot, on); } else { TORCH_INTERNAL_ASSERT( false, "'training' attribute not found. Did you accidentally " "call .eval() before saving your model?"); } for (const auto& slot : obj->slots()) { // slots is a list of IValue. Continue setting training attribute only // if the slot is an object and a module. if (slot.isObject() && slot.toObjectRef().type()->is_module()) { set_train_recurse(slot.toObject(), on); } } } void slot_params_recurse( const c10::intrusive_ptr& obj, std::vector* params) { for (const auto& slot : obj->slots()) { if (slot.isTensor()) { params->emplace_back(slot.toTensor()); } else if (slot.isObject()) { slot_params_recurse(slot.toObject(), params); } } } void slot_named_params_recurse( const c10::intrusive_ptr& obj, std::map* params, const std::string& parent_name) { auto slots = obj->slots(); size_t nslots = slots.size(); for (const auto i : c10::irange(nslots)) { auto slot = slots[i]; std::string name = parent_name.empty() ? parent_name : parent_name + "."; name += obj->type()->getAttributeName(i); // TODO: Fix this filter. Requires_grad is not the appropriate // filter of a parameter, but is a temporary hack to help probable // users of this api. The correct behavior is to filter by the // obj->type->is_parameter() but this currently always returns // false on mobile. if (slot.isTensor() && slot.toTensor().requires_grad()) { (*params)[name] = slot.toTensor(); } else if (slot.isObject()) { slot_named_params_recurse(slot.toObject(), params, name); } } } #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) std::string getTopModuleTypeName(const Module& m) { std::string name; if (m._ivalue()->type() && m._ivalue()->type()->name()) { name = m._ivalue()->type()->name().value().name(); } return name; } #endif } // namespace const std::vector Module::parameters() const { std::vector params; slot_params_recurse(object_, ¶ms); return params; } // Returns a mapping for all attributes that requires_grad=True in a module. // This behavior differs from full torch script modules. This is a bug, // but currently there is no way to correctly label parameters in the // loading of a mobile module. TODO const std::map Module::named_parameters() const { std::map params; const std::string name = ""; slot_named_params_recurse(object_, ¶ms, name); return params; } std::string Module::getModuleHierarchy(const int64_t debug_handle) const { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) return getDebugTable().getModuleHierarchyInfo( debug_handle, getTopModuleTypeName(*this)); #else return ""; #endif } std::string Module::getCallStack(const int64_t debug_handle) const { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) return getDebugTable().getSourceDebugString( debug_handle, getTopModuleTypeName(*this)); #else return ""; #endif } // We will continue to support this API for now as this is being relied upon // for profiling. // We really need to change this part, so in the next step for profiling support // for delegates, the first thing will be to rewrite how profiling is done // for lite interpreter. std::string Module::get_forward_method_debug_info(int64_t debug_handle) const { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) return getDebugTable().getModuleHierarchyInfo( debug_handle, getTopModuleTypeName(*this)); #else return ""; #endif } void Module::train(bool on) { set_train_recurse(object_, on); } bool Module::is_training() const { if (auto slot = object_->type()->findAttributeSlot("training")) { return object_->getSlot(*slot).toBool(); } return true; } const std::vector Module::get_methods() const { std::vector methods; for (std::unique_ptr& fn : cu_->methods()) { methods.emplace_back(this, fn.get()); } return methods; } Method::Method(const Module* owner, Function* function) : owner_(owner), function_(function) {} void Method::run(Stack& stack) const { auto observer = torch::observerConfig().getModuleObserver(); // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand) auto instance_key = std::rand(); /* if the metadata dict doesn't contain "model_name", copy the metadata and set the value of "model_name" as name() */ std::unordered_map copied_metadata = owner_->getMetadata(); if (observer) { observer->onEnterRunMethod(instance_key); } auto debug_info = std::make_shared(); std::string name = copied_metadata["model_name"]; debug_info->setModelName(name); debug_info->setMethodName(function_->name()); at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info); std::string error_message; auto failure_guard = c10::make_scope_exit([&]() { if (!observer) { return; } #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) if (error_message.empty()) { error_message = owner_->getDebugTable().getSourceDebugString( function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_)); } #endif observer->onFailRunMethod( copied_metadata, function_->name(), instance_key, error_message.empty() ? "Unknown exception" : error_message.c_str()); }); try { stack.insert(stack.begin(), owner_->_ivalue()); // self function_->run(stack); if (observer) { observer->onExitRunMethod( copied_metadata, function_->name(), instance_key); } failure_guard.release(); // This exception must be caught first as it derived from c10::Error } catch (c10::BackendRuntimeException& e) { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) for (auto handle : function_->getExceptionDebugHandles()) { e.pushDebugHandle(handle); } // symbolicate all handles auto debug_string = owner_->getDebugTable().getSourceDebugString( e.getDebugHandles(), getTopModuleTypeName(*owner_)); e.add_context(debug_string); #endif error_message = e.what(); TORCH_RETHROW(e); } catch (c10::Error& error) { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) auto debug_string = owner_->getDebugTable().getSourceDebugString( function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_)); error.add_context(debug_string); #endif error_message = error.what(); TORCH_RETHROW(error); } } c10::IValue Method::operator()(std::vector stack) const { run(stack); TORCH_INTERNAL_ASSERT(!stack.empty()); return stack.front(); } static std::optional print_type(const c10::Type& t) { auto namedType = t.cast(); if (namedType && namedType->name()) { return namedType->name().value().qualifiedName(); } if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } return std::nullopt; } TORCH_API ModuleInfo get_module_info(const mobile::Module& module) { ModuleInfo minfo; minfo.operator_version = module.min_operator_version(); minfo.bytecode_version = module.bytecode_version(); std::vector type_name_list; for (const auto& func_ptr : module.compilation_unit().methods()) { const auto& function = *func_ptr; for (const auto i : c10::irange(function.get_code().op_names_.size())) { const auto& op = function.get_code().op_names_[i]; minfo.opname_to_num_args[mobile::operator_str(op)] = function.get_code().operator_input_sizes_[i]; } for (const c10::TypePtr& tp : function.get_code().types_) { type_name_list.push_back(tp->annotation_str(print_type)); } minfo.function_names.insert(function.qualname().qualifiedName()); } c10::TypeParser parser(type_name_list); parser.parseList(); minfo.type_names = parser.getContainedTypes(); return minfo; } } // namespace mobile } // namespace torch::jit