mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Enable UFMT on a bunch of low traffic Python files outside of main files (#106052)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106052
Approved by: https://github.com/albanD, https://github.com/Skylion007
committed by PyTorch MergeBot
parent 5a114f72bf
commit f70844bec7
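
For context, UFMT runs usort (import sorting) and Black (code formatting) together, so the hunks below are purely mechanical reformatting with no behavior change. A minimal sketch of the conventions it enforces, on a hypothetical module rather than code from this PR:

# Hypothetical example of UFMT (usort + Black) output; names are illustrative only.
from typing import Tuple  # usort: standard-library imports first, alphabetized

import torch  # third-party imports grouped after a blank line

ASSET_DIR = "assets/"  # Black: double quotes for all string literals


def save_for_mobile(module: torch.nn.Module, file_name: str) -> Tuple[str, str]:
    # Black wraps calls that exceed its line-length limit across several lines
    # (as in the hunks below) and puts two blank lines before top-level defs.
    path = ASSET_DIR + file_name
    torch.jit.script(module)._save_for_lite_interpreter(path)
    return file_name, path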
@@ -1,18 +1,21 @@
+from typing import Dict, List, Optional, Tuple
+
 import torch
 from torch import Tensor
-from typing import Dict, List, Tuple, Optional
 
 OUTPUT_DIR = "src/androidTest/assets/"
 
+
 def scriptAndSave(module, fileName):
-    print('-' * 80)
+    print("-" * 80)
     script_module = torch.jit.script(module)
     print(script_module.graph)
     outputFileName = OUTPUT_DIR + fileName
     # note that the lite interpreter model can also be used in full JIT
     script_module._save_for_lite_interpreter(outputFileName)
     print("Saved to " + outputFileName)
-    print('=' * 80)
+    print("=" * 80)
+
 
 class Test(torch.jit.ScriptModule):
     @torch.jit.script_method
@@ -73,7 +76,9 @@ class Test(torch.jit.ScriptModule):
         return res
 
     @torch.jit.script_method
-    def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
+    def tupleIntSumReturnTuple(
+        self, input: Tuple[int, int, int]
+    ) -> Tuple[Tuple[int, int, int], int]:
         sum = 0
         for x in input:
             sum += x
@@ -114,7 +119,7 @@ class Test(torch.jit.ScriptModule):
     @torch.jit.script_method
     def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
         r = torch.nn.functional.conv2d(x, w)
-        if (toChannelsLast):
+        if toChannelsLast:
             r = r.contiguous(memory_format=torch.channels_last)
         else:
             r = r.contiguous()
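
As an aside on the conv2d method in the hunk above, a quick standalone check (not part of this PR) of what the channels_last branch changes, namely the memory layout rather than the values:

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 8, 8)
w = torch.rand(4, 3, 3, 3)
r = F.conv2d(x, w)

# channels_last keeps the NCHW sizes but stores the data in NHWC order.
r_cl = r.contiguous(memory_format=torch.channels_last)
print(r_cl.is_contiguous(memory_format=torch.channels_last))  # True
print(torch.equal(r, r_cl))  # True: same values, different strides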
@@ -132,4 +137,5 @@ class Test(torch.jit.ScriptModule):
     def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
         return x.contiguous(memory_format=torch.channels_last_3d)
 
+
 scriptAndSave(Test(), "test.pt")
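
The comment inside scriptAndSave notes that the lite-interpreter artifact can also be used in full JIT; a minimal sketch (paths assumed from OUTPUT_DIR and the scriptAndSave call above) of loading the generated asset back in Python:

import torch

# scriptAndSave writes OUTPUT_DIR + fileName, i.e. "src/androidTest/assets/test.pt".
m = torch.jit.load("src/androidTest/assets/test.pt")

# The script methods defined on Test are exposed on the loaded module,
# e.g. the conv2d wrapper from the hunk further up.
out = m.conv2d(torch.rand(1, 3, 8, 8), torch.rand(4, 3, 3, 3), True)
print(out.shape, out.is_contiguous(memory_format=torch.channels_last))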
@@ -5,12 +5,20 @@ print(torch.version.__version__)
 
 resnet18 = torchvision.models.resnet18(pretrained=True)
 resnet18.eval()
-resnet18_traced = torch.jit.trace(resnet18, torch.rand(1, 3, 224, 224)).save("app/src/main/assets/resnet18.pt")
+resnet18_traced = torch.jit.trace(resnet18, torch.rand(1, 3, 224, 224)).save(
+    "app/src/main/assets/resnet18.pt"
+)
 
 resnet50 = torchvision.models.resnet50(pretrained=True)
 resnet50.eval()
-torch.jit.trace(resnet50, torch.rand(1, 3, 224, 224)).save("app/src/main/assets/resnet50.pt")
+torch.jit.trace(resnet50, torch.rand(1, 3, 224, 224)).save(
+    "app/src/main/assets/resnet50.pt"
+)
 
-mobilenet2q = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)
+mobilenet2q = torchvision.models.quantization.mobilenet_v2(
+    pretrained=True, quantize=True
+)
 mobilenet2q.eval()
-torch.jit.trace(mobilenet2q, torch.rand(1, 3, 224, 224)).save("app/src/main/assets/mobilenet2q.pt")
+torch.jit.trace(mobilenet2q, torch.rand(1, 3, 224, 224)).save(
+    "app/src/main/assets/mobilenet2q.pt"
+)
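
A hedged usage sketch (not part of the script above) showing how one of the saved assets could be smoke-tested from Python before being bundled into the Android test app:

import torch

# Load the traced ResNet-18 saved above and run a dummy forward pass.
m = torch.jit.load("app/src/main/assets/resnet18.pt")
m.eval()
with torch.no_grad():
    out = m(torch.rand(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])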
@@ -21,5 +21,5 @@ traced_script_module.save("MobileNetV2.pt")
 # Dump root ops used by the model (for custom build optimization).
 ops = torch.jit.export_opnames(traced_script_module)
 
-with open('MobileNetV2.yaml', 'w') as output:
+with open("MobileNetV2.yaml", "w") as output:
     yaml.dump(ops, output)
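
For completeness, a small sketch of reading the dumped root-op list back (assuming PyYAML is installed and the script above has been run), which is what a custom mobile build would consume:

import yaml

with open("MobileNetV2.yaml") as f:
    ops = yaml.safe_load(f)

# torch.jit.export_opnames returns fully qualified operator names, for example
# strings like "aten::conv2d"; the exact list depends on the traced model.
print(len(ops))
print(ops[:5])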