[vulkan][android][test_app] Add test_app variant that runs module on Vulkan (#44897)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44897

Test Plan: Imported from OSS

Reviewed By: dreiss

Differential Revision: D23763770

Pulled By: IvanKobzarev

fbshipit-source-id: 6ad16b7271c745313a71da64a629a764258bbc85
This commit is contained in:
Ivan Kobzarev
2020-09-29 09:56:01 -07:00
committed by Facebook GitHub Bot
parent 2c300fd74c
commit 17be7c6e5c
2 changed files with 14 additions and 2 deletions

View File

@ -40,6 +40,7 @@ android {
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
buildConfigField("boolean", "NATIVE_BUILD", 'false')
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
}
buildTypes {
@ -66,9 +67,17 @@ android {
addManifestPlaceholders([APP_NAME: "MBQ"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
}
mbvulkan {
dimension "model"
applicationIdSuffix ".mbvulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MBQ"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"")
}
resnet18 {
dimension "model"
applicationIdSuffix ".resneti18"
applicationIdSuffix ".resnet18"
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"")
addManifestPlaceholders([APP_NAME: "RN18"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")

View File

@ -17,6 +17,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
@ -126,7 +127,9 @@ public class MainActivity extends AppCompatActivity {
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
PyTorchAndroid.setNumThreads(1);
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
mModule = BuildConfig.USE_VULKAN_DEVICE
? PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
}
final long startTime = SystemClock.elapsedRealtime();