diff --git a/android/test_app/app/build.gradle b/android/test_app/app/build.gradle index c592728ce9f4..9be0ad2d7aa6 100644 --- a/android/test_app/app/build.gradle +++ b/android/test_app/app/build.gradle @@ -40,6 +40,7 @@ android { buildConfigField("String", "LOGCAT_TAG", "@string/app_name") buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}") buildConfigField("boolean", "NATIVE_BUILD", 'false') + buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false') addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"]) } buildTypes { @@ -66,9 +67,17 @@ android { addManifestPlaceholders([APP_NAME: "MBQ"]) buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"") } + mbvulkan { + dimension "model" + applicationIdSuffix ".mbvulkan" + buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"") + buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true') + addManifestPlaceholders([APP_NAME: "MBQ"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"") + } resnet18 { dimension "model" - applicationIdSuffix ".resneti18" + applicationIdSuffix ".resnet18" buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"") addManifestPlaceholders([APP_NAME: "RN18"]) buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"") diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java index 5cc233011c8a..1510890a15b8 100644 --- a/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.FloatBuffer; +import org.pytorch.Device; import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.PyTorchAndroid; @@ -126,7 +127,9 @@ public class MainActivity extends AppCompatActivity { mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements); mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE); PyTorchAndroid.setNumThreads(1); - mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME); + mModule = BuildConfig.USE_VULKAN_DEVICE + ? PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN) + : PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME); } final long startTime = SystemClock.elapsedRealtime();