mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
[vulkan][android][test_app] Add test_app variant that runs module on Vulkan (#44897)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44897 Test Plan: Imported from OSS Reviewed By: dreiss Differential Revision: D23763770 Pulled By: IvanKobzarev fbshipit-source-id: 6ad16b7271c745313a71da64a629a764258bbc85
This commit is contained in:
committed by
Facebook GitHub Bot
parent
2c300fd74c
commit
17be7c6e5c
@ -40,6 +40,7 @@ android {
|
||||
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
|
||||
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
|
||||
buildConfigField("boolean", "NATIVE_BUILD", 'false')
|
||||
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
|
||||
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
|
||||
}
|
||||
buildTypes {
|
||||
@ -66,9 +67,17 @@ android {
|
||||
addManifestPlaceholders([APP_NAME: "MBQ"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
|
||||
}
|
||||
mbvulkan {
|
||||
dimension "model"
|
||||
applicationIdSuffix ".mbvulkan"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"")
|
||||
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
|
||||
addManifestPlaceholders([APP_NAME: "MBQ"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"")
|
||||
}
|
||||
resnet18 {
|
||||
dimension "model"
|
||||
applicationIdSuffix ".resneti18"
|
||||
applicationIdSuffix ".resnet18"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"")
|
||||
addManifestPlaceholders([APP_NAME: "RN18"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
|
||||
|
||||
@ -17,6 +17,7 @@ import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.FloatBuffer;
|
||||
import org.pytorch.Device;
|
||||
import org.pytorch.IValue;
|
||||
import org.pytorch.Module;
|
||||
import org.pytorch.PyTorchAndroid;
|
||||
@ -126,7 +127,9 @@ public class MainActivity extends AppCompatActivity {
|
||||
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
|
||||
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
|
||||
PyTorchAndroid.setNumThreads(1);
|
||||
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
|
||||
mModule = BuildConfig.USE_VULKAN_DEVICE
|
||||
? PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
|
||||
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
|
||||
}
|
||||
|
||||
final long startTime = SystemClock.elapsedRealtime();
|
||||
|
||||
Reference in New Issue
Block a user