mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary:
Original commit changeset: 96813f0fac68
Original Phabricator Diff: D50161780
This breaks the integration test on T166457344.
Test Plan: Sandcastle.
Differential Revision: D50344243
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111401
Approved by: https://github.com/izaitsevfb
22 lines
689 B
Python
22 lines
689 B
Python
import torch
|
|
import torchvision
|
|
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
|
|
from torch.utils.mobile_optimizer import optimize_for_mobile
|
|
|
|
|
|
class MobileNetV2Module:
    """Factory for a traced, mobile-optimized MobileNetV2 with bundled inputs."""

    def getModule(self):
        """Build and return the mobile-optimized MobileNetV2 script module.

        The model is created via ``torchvision.models.mobilenet_v2`` with
        ``pretrained=True``, switched to eval mode, traced with an all-zeros
        ``(1, 3, 224, 224)`` tensor, and passed through
        ``optimize_for_mobile``. The same tensor is then attached to the
        optimized module as its single bundled input.

        Returns:
            The module produced by ``optimize_for_mobile`` after it has been
            augmented with bundled inputs.
        """
        # One all-zeros tensor serves as both the tracing example and the
        # bundled input.
        sample_input = torch.zeros(1, 3, 224, 224)

        # NOTE(review): newer torchvision releases appear to prefer a
        # ``weights=`` argument over ``pretrained=`` -- confirm against the
        # torchvision version this script is pinned to before changing it.
        net = torchvision.models.mobilenet_v2(pretrained=True)
        net.eval()

        mobile_module = optimize_for_mobile(torch.jit.trace(net, sample_input))

        # A list with one entry; each entry is the tuple of positional
        # arguments for a single call of the module.
        bundled_inputs = [(sample_input,)]
        augment_model_with_bundled_inputs(mobile_module, bundled_inputs)

        # The output is discarded; presumably this forward pass is a smoke
        # check that the optimized module still runs -- TODO confirm.
        mobile_module(sample_input)
        return mobile_module
|