Subscribe for record function and if android do atrace (#28708)

Summary:
ghstack-source-id: 5edaf471557c25098ca0547229f2763760866887
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28708

Some cpp formatting changes as I run `clang-format -i`

Testing on devserver:
make assets (models):
```
pushd android/test_app/; python make_assets.py; popd
```
Build test_app apk:
```
TRACE_ENABLED=1 sh android/build_test_app.sh

find . -type f -name *apk
./android/test_app/app/build/outputs/apk/mobNet2Quant/debug/test_app-mobNet2Quant-debug.apk
./android/test_app/app/build/outputs/apk/resnet18/debug/test_app-resnet18-debug.apk
```

Install apk:
`adb install -r test_app-mobNet2Quant-debug.apk`
Run app on the device.
Systrace:
```
$ANDROID_HOME/platform-tools/systrace/systrace.py -t 10 -a org.pytorch.testapp.mobNet2Quant sched freq idle am wm gfx view binder_driver hal dalvik camera input res -o trace.html
```
trace.html contains sections like `jni::Module::forward`

![Screenshot 2019-11-12 18 36 30](https://user-images.githubusercontent.com/6638825/68728156-5d245580-057b-11ea-9e71-e47681894fe4.png)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28712

Differential Revision: D18495898

Pulled By: IvanKobzarev

fbshipit-source-id: 0bced4a442f9dd90525520972a2c1f5d51f57df3
This commit is contained in:
Ivan Kobzarev
2019-11-13 20:53:20 -08:00
committed by Facebook Github Bot
parent a68c52494c
commit aa6e992ffb
7 changed files with 121 additions and 2 deletions

View File

@ -47,7 +47,7 @@ echo "ABIS_LIST:$ABIS_LIST"
LIB_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs LIB_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs
INCLUDE_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include INCLUDE_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include
mkdir -p $LIB_DIR mkdir -p $LIB_DIR
rm $LIB_DIR/* rm -f $LIB_DIR/*
mkdir -p $INCLUDE_DIR mkdir -p $INCLUDE_DIR
for abi in $(echo $ABIS_LIST | tr ',' '\n') for abi in $(echo $ABIS_LIST | tr ',' '\n')

View File

@ -3,6 +3,17 @@ project(pytorch_jni CXX)
set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD 11)
set(CMAKE_VERBOSE_MAKEFILE ON) set(CMAKE_VERBOSE_MAKEFILE ON)
set(TRACE_ENABLED OFF)
if(DEFINED ENV{TRACE_ENABLED})
if($ENV{TRACE_ENABLED} STREQUAL "1")
message(STATUS "TRACE_ENABLED ON")
set(TRACE_ENABLED ON)
endif()
endif()
if(NOT TRACE_ENABLED)
message(STATUS "TRACE_ENABLED OFF")
endif()
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
if (ANDROID_ABI) if (ANDROID_ABI)
@ -20,6 +31,10 @@ endif()
message(STATUS "libtorch dir:${libtorch_DIR}") message(STATUS "libtorch dir:${libtorch_DIR}")
configure_file(
${pytorch_android_DIR}/cmake_macros.h.in
${pytorch_android_DIR}/cmake_macros.h)
file(GLOB pytorch_android_SOURCES file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_jit.cpp ${pytorch_android_DIR}/pytorch_jni_jit.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp ${pytorch_android_DIR}/pytorch_jni_common.cpp

View File

@ -0,0 +1,3 @@
#pragma once
/* #undef TRACE_ENABLED */

View File

@ -0,0 +1,3 @@
#pragma once
#cmakedefine TRACE_ENABLED

View File

@ -10,6 +10,25 @@
namespace pytorch_jni { namespace pytorch_jni {
bool Trace::is_initialized_ = false;
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
Trace::fp_ATrace_beginSection Trace::ATrace_beginSection;
Trace::fp_ATrace_endSection Trace::ATrace_endSection;
#endif
void Trace::init() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
void* lib = dlopen("libandroid.so", RTLD_NOW || RTLD_LOCAL);
if (lib != NULL) {
Trace::ATrace_beginSection = reinterpret_cast<fp_ATrace_beginSection>(
dlsym(lib, "ATrace_beginSection"));
Trace::ATrace_endSection =
reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"));
}
#endif
}
// NOTE: Codes must be kept in sync with DType.java. // NOTE: Codes must be kept in sync with DType.java.
// NOTE: Never serialize these, because they can change between releases. // NOTE: Never serialize these, because they can change between releases.
constexpr static int kTensorDTypeUInt8 = 1; constexpr static int kTensorDTypeUInt8 = 1;
@ -186,6 +205,7 @@ public:
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue( facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
const at::IValue& ivalue) { const at::IValue& ivalue) {
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
if (ivalue.isNone()) { if (ivalue.isNone()) {
static auto jMethodOptionalNull = static auto jMethodOptionalNull =
JIValue::javaClassStatic() JIValue::javaClassStatic()
@ -375,6 +395,7 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
at::IValue JIValue::JIValueToAtIValue( at::IValue JIValue::JIValueToAtIValue(
facebook::jni::alias_ref<JIValue> jivalue) { facebook::jni::alias_ref<JIValue> jivalue) {
Trace _s{"jni::JIValue::JIValueToAtIValue"};
static const auto typeCodeField = static const auto typeCodeField =
JIValue::javaClassStatic()->getField<jint>("mTypeCode"); JIValue::javaClassStatic()->getField<jint>("mTypeCode");
const auto typeCode = jivalue->getFieldValue(typeCodeField); const auto typeCode = jivalue->getFieldValue(typeCodeField);

View File

@ -1,9 +1,66 @@
#include <fbjni/fbjni.h> #include <fbjni/fbjni.h>
#include <torch/csrc/api/include/torch/types.h> #include <torch/csrc/api/include/torch/types.h>
#include "cmake_macros.h"
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
#include <android/log.h>
#include <android/trace.h>
#include <dlfcn.h>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__)
#endif
namespace pytorch_jni { namespace pytorch_jni {
class Trace {
public:
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
typedef void* (*fp_ATrace_beginSection)(const char* sectionName);
typedef void* (*fp_ATrace_endSection)(void);
static fp_ATrace_beginSection ATrace_beginSection;
static fp_ATrace_endSection ATrace_endSection;
#endif
static void ensureInit() {
if (!Trace::is_initialized_) {
init();
Trace::is_initialized_ = true;
}
}
static void beginSection(const char* name) {
Trace::ensureInit();
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
ATrace_beginSection(name);
#endif
}
static void endSection() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
ATrace_endSection();
#endif
}
Trace(const char* name) {
ensureInit();
beginSection(name);
}
~Trace() {
endSection();
}
private:
static void init();
static bool is_initialized_;
};
class JIValue : public facebook::jni::JavaClass<JIValue> { class JIValue : public facebook::jni::JavaClass<JIValue> {
public: public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;"; constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
constexpr static int kTypeCodeNull = 1; constexpr static int kTypeCodeNull = 1;

View File

@ -6,6 +6,7 @@
#include <fbjni/ByteBuffer.h> #include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h> #include <fbjni/fbjni.h>
#include <torch/csrc/autograd/record_function.h>
#include <torch/script.h> #include <torch/script.h>
#include "pytorch_jni_common.h" #include "pytorch_jni_common.h"
@ -26,12 +27,30 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
return makeCxxInstance(modelPath); return makeCxxInstance(modelPath);
} }
#ifdef TRACE_ENABLED
static void onFunctionEnter(
const torch::autograd::profiler::RecordFunction& fn) {
Trace::beginSection(fn.name().str());
}
static void onFunctionExit(const torch::autograd::profiler::RecordFunction&) {
Trace::endSection();
}
#endif
PytorchJni(facebook::jni::alias_ref<jstring> modelPath) { PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
auto qengines = at::globalContext().supportedQEngines(); auto qengines = at::globalContext().supportedQEngines();
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
qengines.end()) { qengines.end()) {
at::globalContext().setQEngine(at::QEngine::QNNPACK); at::globalContext().setQEngine(at::QEngine::QNNPACK);
} }
#ifdef TRACE_ENABLED
torch::autograd::profiler::pushCallback(
&onFunctionEnter,
&onFunctionExit,
/* need_inputs */ false,
/* sampled */ false);
#endif
module_ = torch::jit::load(std::move(modelPath->toStdString())); module_ = torch::jit::load(std::move(modelPath->toStdString()));
module_.eval(); module_.eval();
} }
@ -48,6 +67,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
facebook::jni::alias_ref< facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject> facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) { jinputs) {
Trace _s{"jni::Module::forward"};
std::vector<at::IValue> inputs{}; std::vector<at::IValue> inputs{};
size_t n = jinputs->size(); size_t n = jinputs->size();
inputs.reserve(n); inputs.reserve(n);