mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This adds 2 jobs to build PyTorch Android with and without the lite interpreter: * Keep the list of currently supported ABIs: armeabi-v7a, arm64-v8a, x86, x86_64 * Pass all the tests on the emulator * Run the test app on the emulator and on my Android phone (`arm64-v8a`) without any issues * Run on AWS https://us-west-2.console.aws.amazon.com/devicefarm/home#/mobile/projects/b531574a-fb82-40ae-b687-8f0b81341ae0/runs/5fce6818-628a-4099-9aab-23e91a212076 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110976 Approved by: https://github.com/atalman
56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
import torch
|
|
import torchvision
|
|
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
|
|
from torch.utils.mobile_optimizer import optimize_for_mobile
|
|
|
|
|
|
class MobileNetV2Module:
    """Produce a mobile-optimized, traced MobileNetV2 with one bundled input."""

    def getModule(self):
        """Return torchvision's pretrained MobileNetV2 as a mobile-ready script module.

        Steps: load the pretrained model in eval mode, trace it with an
        all-zero 1x3x224x224 tensor, run ``optimize_for_mobile`` on the
        trace, bundle that same zero tensor as the module's sample input,
        and finally execute one forward pass as a smoke check.
        """
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # releases in favor of ``weights=``; kept for compatibility.
        net = torchvision.models.mobilenet_v2(pretrained=True)
        net.eval()

        dummy_input = torch.zeros(1, 3, 224, 224)
        mobile_module = optimize_for_mobile(torch.jit.trace(net, dummy_input))

        # A single positional-argument tuple: the zero tensor used for tracing.
        sample_inputs = [(dummy_input, )]
        augment_model_with_bundled_inputs(mobile_module, sample_inputs)

        # Smoke-run the optimized module once; the output is intentionally unused.
        mobile_module(dummy_input)
        return mobile_module
|
|
|
|
|
|
class MobileNetV2VulkanModule:
    """Produce a Vulkan-targeted, traced MobileNetV2 with one bundled input."""

    def getModule(self):
        """Return pretrained MobileNetV2 optimized for the ``"vulkan"`` backend.

        Identical pipeline to the CPU variant except that
        ``optimize_for_mobile`` is called with ``backend="vulkan"``: load the
        pretrained model in eval mode, trace it on an all-zero 1x3x224x224
        tensor, optimize, bundle the zero tensor as the sample input, and
        run one forward pass as a smoke check.
        """
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # releases in favor of ``weights=``; kept for compatibility.
        net = torchvision.models.mobilenet_v2(pretrained=True)
        net.eval()

        dummy_input = torch.zeros(1, 3, 224, 224)
        traced = torch.jit.trace(net, dummy_input)
        vulkan_module = optimize_for_mobile(traced, backend="vulkan")

        # One bundled sample: the same zero tensor the trace was recorded with.
        augment_model_with_bundled_inputs(vulkan_module, [(dummy_input, )])

        # NOTE(review): this smoke run feeds a plain torch.zeros tensor to the
        # Vulkan-optimized module; confirm it is supported on the build host.
        vulkan_module(dummy_input)
        return vulkan_module
|
|
|
|
|
|
class Resnet18Module:
    """Produce a mobile-optimized, traced ResNet-18 with one bundled input."""

    def getModule(self):
        """Return torchvision's pretrained ResNet-18 as a mobile-ready script module.

        The model is put in eval mode and traced with an all-zero
        1x3x224x224 tensor. The trace is passed through ``optimize_for_mobile``
        and the zero tensor is bundled as its sample input. One forward pass
        is executed before returning, as a smoke check.
        """
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # releases in favor of ``weights=``; kept for compatibility.
        backbone = torchvision.models.resnet18(pretrained=True)
        backbone.eval()

        zero_batch = torch.zeros(1, 3, 224, 224)
        traced = torch.jit.trace(backbone, zero_batch)
        result = optimize_for_mobile(traced)

        bundled = [(zero_batch, )]
        augment_model_with_bundled_inputs(result, bundled)

        # Forward once to make sure the optimized module executes; output discarded.
        result(zero_batch)
        return result
|