mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Subscribe for record function and if android do atrace (#28708)
Summary: ghstack-source-id: 5edaf471557c25098ca0547229f2763760866887 Pull Request resolved: https://github.com/pytorch/pytorch/pull/28708 Some cpp formatting changes as I run `clang-format -i` Testing on devserver: make assets (models): ``` pushd android/test_app/; python make_assets.py; popd ``` Build test_app apk: ``` TRACE_ENABLED=1 sh android/build_test_app.sh find . -type f -name *apk ./android/test_app/app/build/outputs/apk/mobNet2Quant/debug/test_app-mobNet2Quant-debug.apk ./android/test_app/app/build/outputs/apk/resnet18/debug/test_app-resnet18-debug.apk ``` Install apk: `adb install -r test_app-mobNet2Quant-debug.apk` Run app on the device. Systrace: ``` $ANDROID_HOME/platform-tools/systrace/systrace.py -t 10 -a org.pytorch.testapp.mobNet2Quant sched freq idle am wm gfx view binder_driver hal dalvik camera input res -o trace.html ``` trace.html contains sections like `jni::Module::forward`  Pull Request resolved: https://github.com/pytorch/pytorch/pull/28712 Differential Revision: D18495898 Pulled By: IvanKobzarev fbshipit-source-id: 0bced4a442f9dd90525520972a2c1f5d51f57df3
This commit is contained in:
committed by
Facebook Github Bot
parent
a68c52494c
commit
aa6e992ffb
@ -47,7 +47,7 @@ echo "ABIS_LIST:$ABIS_LIST"
|
|||||||
LIB_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs
|
LIB_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs
|
||||||
INCLUDE_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include
|
INCLUDE_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include
|
||||||
mkdir -p $LIB_DIR
|
mkdir -p $LIB_DIR
|
||||||
rm $LIB_DIR/*
|
rm -f $LIB_DIR/*
|
||||||
mkdir -p $INCLUDE_DIR
|
mkdir -p $INCLUDE_DIR
|
||||||
|
|
||||||
for abi in $(echo $ABIS_LIST | tr ',' '\n')
|
for abi in $(echo $ABIS_LIST | tr ',' '\n')
|
||||||
|
@ -3,6 +3,17 @@ project(pytorch_jni CXX)
|
|||||||
set(CMAKE_CXX_STANDARD 11)
|
set(CMAKE_CXX_STANDARD 11)
|
||||||
set(CMAKE_VERBOSE_MAKEFILE ON)
|
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||||
|
|
||||||
|
set(TRACE_ENABLED OFF)
|
||||||
|
if(DEFINED ENV{TRACE_ENABLED})
|
||||||
|
if($ENV{TRACE_ENABLED} STREQUAL "1")
|
||||||
|
message(STATUS "TRACE_ENABLED ON")
|
||||||
|
set(TRACE_ENABLED ON)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
if(NOT TRACE_ENABLED)
|
||||||
|
message(STATUS "TRACE_ENABLED OFF")
|
||||||
|
endif()
|
||||||
|
|
||||||
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
|
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
|
||||||
|
|
||||||
if (ANDROID_ABI)
|
if (ANDROID_ABI)
|
||||||
@ -20,6 +31,10 @@ endif()
|
|||||||
|
|
||||||
message(STATUS "libtorch dir:${libtorch_DIR}")
|
message(STATUS "libtorch dir:${libtorch_DIR}")
|
||||||
|
|
||||||
|
configure_file(
|
||||||
|
${pytorch_android_DIR}/cmake_macros.h.in
|
||||||
|
${pytorch_android_DIR}/cmake_macros.h)
|
||||||
|
|
||||||
file(GLOB pytorch_android_SOURCES
|
file(GLOB pytorch_android_SOURCES
|
||||||
${pytorch_android_DIR}/pytorch_jni_jit.cpp
|
${pytorch_android_DIR}/pytorch_jni_jit.cpp
|
||||||
${pytorch_android_DIR}/pytorch_jni_common.cpp
|
${pytorch_android_DIR}/pytorch_jni_common.cpp
|
||||||
|
3
android/pytorch_android/src/main/cpp/cmake_macros.h
Normal file
3
android/pytorch_android/src/main/cpp/cmake_macros.h
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
/* #undef TRACE_ENABLED */
|
3
android/pytorch_android/src/main/cpp/cmake_macros.h.in
Normal file
3
android/pytorch_android/src/main/cpp/cmake_macros.h.in
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#cmakedefine TRACE_ENABLED
|
@ -10,6 +10,25 @@
|
|||||||
|
|
||||||
namespace pytorch_jni {
|
namespace pytorch_jni {
|
||||||
|
|
||||||
|
bool Trace::is_initialized_ = false;
|
||||||
|
|
||||||
|
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||||
|
Trace::fp_ATrace_beginSection Trace::ATrace_beginSection;
|
||||||
|
Trace::fp_ATrace_endSection Trace::ATrace_endSection;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void Trace::init() {
|
||||||
|
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||||
|
void* lib = dlopen("libandroid.so", RTLD_NOW || RTLD_LOCAL);
|
||||||
|
if (lib != NULL) {
|
||||||
|
Trace::ATrace_beginSection = reinterpret_cast<fp_ATrace_beginSection>(
|
||||||
|
dlsym(lib, "ATrace_beginSection"));
|
||||||
|
Trace::ATrace_endSection =
|
||||||
|
reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: Codes must be kept in sync with DType.java.
|
// NOTE: Codes must be kept in sync with DType.java.
|
||||||
// NOTE: Never serialize these, because they can change between releases.
|
// NOTE: Never serialize these, because they can change between releases.
|
||||||
constexpr static int kTensorDTypeUInt8 = 1;
|
constexpr static int kTensorDTypeUInt8 = 1;
|
||||||
@ -186,6 +205,7 @@ public:
|
|||||||
|
|
||||||
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
||||||
const at::IValue& ivalue) {
|
const at::IValue& ivalue) {
|
||||||
|
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
|
||||||
if (ivalue.isNone()) {
|
if (ivalue.isNone()) {
|
||||||
static auto jMethodOptionalNull =
|
static auto jMethodOptionalNull =
|
||||||
JIValue::javaClassStatic()
|
JIValue::javaClassStatic()
|
||||||
@ -375,6 +395,7 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
|||||||
|
|
||||||
at::IValue JIValue::JIValueToAtIValue(
|
at::IValue JIValue::JIValueToAtIValue(
|
||||||
facebook::jni::alias_ref<JIValue> jivalue) {
|
facebook::jni::alias_ref<JIValue> jivalue) {
|
||||||
|
Trace _s{"jni::JIValue::JIValueToAtIValue"};
|
||||||
static const auto typeCodeField =
|
static const auto typeCodeField =
|
||||||
JIValue::javaClassStatic()->getField<jint>("mTypeCode");
|
JIValue::javaClassStatic()->getField<jint>("mTypeCode");
|
||||||
const auto typeCode = jivalue->getFieldValue(typeCodeField);
|
const auto typeCode = jivalue->getFieldValue(typeCodeField);
|
||||||
|
@ -1,9 +1,66 @@
|
|||||||
#include <fbjni/fbjni.h>
|
#include <fbjni/fbjni.h>
|
||||||
#include <torch/csrc/api/include/torch/types.h>
|
#include <torch/csrc/api/include/torch/types.h>
|
||||||
|
|
||||||
|
#include "cmake_macros.h"
|
||||||
|
|
||||||
|
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||||
|
#include <android/log.h>
|
||||||
|
|
||||||
|
#include <android/trace.h>
|
||||||
|
#include <dlfcn.h>
|
||||||
|
|
||||||
|
#define ALOGI(...) \
|
||||||
|
__android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__)
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace pytorch_jni {
|
namespace pytorch_jni {
|
||||||
|
|
||||||
|
class Trace {
|
||||||
|
public:
|
||||||
|
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||||
|
typedef void* (*fp_ATrace_beginSection)(const char* sectionName);
|
||||||
|
typedef void* (*fp_ATrace_endSection)(void);
|
||||||
|
|
||||||
|
static fp_ATrace_beginSection ATrace_beginSection;
|
||||||
|
static fp_ATrace_endSection ATrace_endSection;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
static void ensureInit() {
|
||||||
|
if (!Trace::is_initialized_) {
|
||||||
|
init();
|
||||||
|
Trace::is_initialized_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void beginSection(const char* name) {
|
||||||
|
Trace::ensureInit();
|
||||||
|
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||||
|
ATrace_beginSection(name);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
static void endSection() {
|
||||||
|
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
|
||||||
|
ATrace_endSection();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
Trace(const char* name) {
|
||||||
|
ensureInit();
|
||||||
|
beginSection(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
~Trace() {
|
||||||
|
endSection();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void init();
|
||||||
|
static bool is_initialized_;
|
||||||
|
};
|
||||||
|
|
||||||
class JIValue : public facebook::jni::JavaClass<JIValue> {
|
class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||||
public:
|
public:
|
||||||
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
|
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
|
||||||
|
|
||||||
constexpr static int kTypeCodeNull = 1;
|
constexpr static int kTypeCodeNull = 1;
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include <fbjni/ByteBuffer.h>
|
#include <fbjni/ByteBuffer.h>
|
||||||
#include <fbjni/fbjni.h>
|
#include <fbjni/fbjni.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/autograd/record_function.h>
|
||||||
#include <torch/script.h>
|
#include <torch/script.h>
|
||||||
|
|
||||||
#include "pytorch_jni_common.h"
|
#include "pytorch_jni_common.h"
|
||||||
@ -26,12 +27,30 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||||||
return makeCxxInstance(modelPath);
|
return makeCxxInstance(modelPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef TRACE_ENABLED
|
||||||
|
static void onFunctionEnter(
|
||||||
|
const torch::autograd::profiler::RecordFunction& fn) {
|
||||||
|
Trace::beginSection(fn.name().str());
|
||||||
|
}
|
||||||
|
|
||||||
|
static void onFunctionExit(const torch::autograd::profiler::RecordFunction&) {
|
||||||
|
Trace::endSection();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
|
PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
|
||||||
auto qengines = at::globalContext().supportedQEngines();
|
auto qengines = at::globalContext().supportedQEngines();
|
||||||
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
|
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
|
||||||
qengines.end()) {
|
qengines.end()) {
|
||||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||||
}
|
}
|
||||||
|
#ifdef TRACE_ENABLED
|
||||||
|
torch::autograd::profiler::pushCallback(
|
||||||
|
&onFunctionEnter,
|
||||||
|
&onFunctionExit,
|
||||||
|
/* need_inputs */ false,
|
||||||
|
/* sampled */ false);
|
||||||
|
#endif
|
||||||
module_ = torch::jit::load(std::move(modelPath->toStdString()));
|
module_ = torch::jit::load(std::move(modelPath->toStdString()));
|
||||||
module_.eval();
|
module_.eval();
|
||||||
}
|
}
|
||||||
@ -48,6 +67,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||||||
facebook::jni::alias_ref<
|
facebook::jni::alias_ref<
|
||||||
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
|
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
|
||||||
jinputs) {
|
jinputs) {
|
||||||
|
Trace _s{"jni::Module::forward"};
|
||||||
std::vector<at::IValue> inputs{};
|
std::vector<at::IValue> inputs{};
|
||||||
size_t n = jinputs->size();
|
size_t n = jinputs->size();
|
||||||
inputs.reserve(n);
|
inputs.reserve(n);
|
||||||
|
Reference in New Issue
Block a user