mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[PyTorch][Mobile] Insert the module name as name()
to metadata dict if metadata doesn't contain "model_name" (#44400)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44400
This diff does the identical thing as D23549149 (398409f072
) does. A fix included for OSS CI: pytorch_windows_vs2019_py36_cuda10.1_test1
ghstack-source-id: 111679745
Test Plan:
- CI
- OSS CI
Reviewed By: xcheng16
Differential Revision: D23601050
fbshipit-source-id: 8ebdcd8fdc5865078889b54b0baeb397a90ddc40
This commit is contained in:
committed by
Facebook GitHub Bot
parent
24efd29d19
commit
a00d36b0e7
@ -26,17 +26,19 @@ Function* CompilationUnit::find_function(const c10::QualifiedName& qn) {
|
||||
|
||||
c10::IValue Module::run_method(const std::string& method_name, Stack stack) {
|
||||
auto observer = torch::observerConfig().getModuleObserver();
|
||||
auto module_metadata = metadata();
|
||||
/* if the metadata dict doesn't contain "model_name", copy the metadata and
|
||||
set the value of "model_name" as name() */
|
||||
std::unordered_map<std::string, std::string> copied_metadata = metadata();
|
||||
if (metadata().find("model_name") == metadata().end()) {
|
||||
copied_metadata["model_name"] = name();
|
||||
}
|
||||
if (observer) {
|
||||
observer->onEnterRunMethod(module_metadata, method_name);
|
||||
observer->onEnterRunMethod(copied_metadata, method_name);
|
||||
}
|
||||
|
||||
auto debug_info = std::make_shared<MobileDebugInfo>();
|
||||
if (module_metadata.find("model_name") != module_metadata.end()) {
|
||||
debug_info->setModelName(module_metadata.at("model_name"));
|
||||
} else {
|
||||
debug_info->setModelName(name());
|
||||
}
|
||||
std::string name = copied_metadata["model_name"];
|
||||
debug_info->setModelName(name);
|
||||
debug_info->setMethodName(method_name);
|
||||
at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info);
|
||||
|
||||
|
Reference in New Issue
Block a user