mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352 Approved by: https://github.com/ezyang ghstack dependencies: #132335, #132351
243 lines
7.8 KiB
Python
243 lines
7.8 KiB
Python
#!/usr/bin/env python3
|
|
# Owner(s): ["oncall: mobile"]
|
|
|
|
import os
|
|
import io
|
|
import functools
|
|
import tempfile
|
|
import urllib
|
|
import unittest
|
|
|
|
import torch
|
|
import torch.backends.xnnpack
|
|
import torch.utils.model_dump
|
|
import torch.utils.mobile_optimizer
|
|
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
|
|
from torch.testing._internal.common_quantized import supported_qengines
|
|
|
|
|
|
class SimpleModel(torch.nn.Module):
    """Small two-layer perceptron used as the basic fixture in this file.

    Maps 16 input features to 8 outputs through a 64-unit hidden layer,
    applying a ReLU after each linear layer.
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names and registration order are kept stable: other code
        # in this file addresses these submodules by dotted path
        # (e.g. "core.layer1", "core.relu1").
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        """Return relu2(layer2(relu1(layer1(features))))."""
        hidden = self.relu1(self.layer1(features))
        return self.relu2(self.layer2(hidden))
|
|
|
|
|
|
class QuantModel(torch.nn.Module):
    """A SimpleModel placed between a QuantStub and a DeQuantStub.

    The input passes through ``quant``, then the wrapped ``core`` model,
    then ``dequant``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order (quant, dequant, core) is kept as-is.
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        """Return dequant(core(quant(x)))."""
        quantized = self.quant(x)
        return self.dequant(self.core(quantized))
|
|
|
|
|
|
class ModelWithLists(torch.nn.Module):
    """Module whose attributes are plain Python lists.

    ``rt`` holds a single tensor; ``ot`` holds a tensor followed by ``None``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        """Add ``rt[0]`` to ``arg``, then add ``ot[0]`` when it is not None."""
        total = arg + self.rt[0]
        # The list element is bound to a local first and that local is what
        # gets compared against None before being used in the addition.
        maybe_term = self.ot[0]
        if maybe_term is not None:
            total = total + maybe_term
        return total
|
|
|
|
|
|
def webdriver_test(testfunc):
    """Decorate a test method so it runs once per Selenium browser driver.

    The wrapped method is called as ``testfunc(self, wd, *args, **kwds)``
    where ``wd`` is a freshly created webdriver instance. Each driver
    ("Firefox", then "Chrome") runs inside its own ``self.subTest``.

    The test is skipped unless the environment variable ``RUN_WEBDRIVER``
    is exactly ``"1"``. ``self.needs_resources()`` is always called first.

    Args:
        testfunc: Test method taking ``(self, wd, ...)``.

    Returns:
        A wrapper taking ``(self, ...)`` suitable for use as a test method.
    """

    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()

        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        # Imported lazily so selenium is only needed when the webdriver
        # tests are actually requested.
        from selenium import webdriver

        for driver in [
            "Firefox",
            "Chrome",
        ]:
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                try:
                    testfunc(self, wd, *args, **kwds)
                finally:
                    # Always shut the browser down, even when the test body
                    # fails; otherwise a failing test leaks the browser.
                    # quit() ends the whole driver session rather than only
                    # closing the current window.
                    wd.quit()

    return wrapper
|
|
|
|
|
|
class TestModelDump(TestCase):
    """Smoke tests for ``torch.utils.model_dump``.

    Most tests serialize a small TorchScript model into memory and only
    check that the dump tooling runs. The webdriver test additionally opens
    the generated HTML page in a browser and checks the tensor memory it
    reports.
    """

    def needs_resources(self):
        """Hook called by tests before they use the model_dump page/CLI.

        Intentionally a no-op in this class.
        """
        pass

    def test_inline_skeleton(self):
        """The inline skeleton must not reference external scripts."""
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        # unittest assertions are used instead of bare ``assert`` so the
        # checks survive ``python -O`` and report the offending value.
        # NOTE(review): this looks for the literal host "unpkg.org"; confirm
        # it matches the CDN host the skeleton is expected to inline away.
        self.assertNotIn("unpkg.org", skel)
        self.assertNotIn("src=", skel)

    def do_dump_model(self, model, extra_files=None):
        """Save ``model`` to an in-memory buffer and extract its model info.

        Args:
            model: Object accepted by ``torch.jit.save``.
            extra_files: Optional mapping forwarded as ``_extra_files``.
        """
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        self.assertIsNotNone(info)

    def open_html_model(self, wd, model, extra_files=None):
        """Render ``model`` as a self-contained HTML page and load it in ``wd``.

        Args:
            wd: Selenium webdriver instance.
            model: Object accepted by ``torch.jit.save``.
            extra_files: Optional mapping forwarded as ``_extra_files``.
        """
        # ``import urllib`` by itself does not guarantee that the ``parse``
        # submodule has been loaded, so import it explicitly here.
        import urllib.parse

        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        # The page is passed to the browser as a percent-encoded data: URL.
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        """Expand the section titled ``name`` and return its first ``div``.

        Args:
            wd: Selenium webdriver instance with the model page loaded.
            name: Value of the section's ``data-hider-title`` attribute.

        Returns:
            The first ``div`` element found inside the section container.
        """
        # The ``find_element_by_*`` helpers were removed in Selenium 4.3;
        # ``find_element(By.<strategy>, value)`` is the supported spelling.
        from selenium.webdriver.common.by import By

        container = wd.find_element(By.XPATH, f"//div[@data-hider-title='{name}']")
        caret = container.find_element(By.CLASS_NAME, "caret")
        # Only click the caret when the section is not already expanded.
        if container.get_attribute("data-shown") != "true":
            caret.click()
        content = container.find_element(By.TAG_NAME, "div")
        return content

    def test_scripted_model(self):
        """Dump a scripted SimpleModel."""
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        """Dump a traced SimpleModel."""
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
        """Run ``torch.utils.model_dump.main`` in JSON and HTML styles."""
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI.  Just skip it.
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below
            tf.flush()

            def run_main(style):
                # The first list entry is None (presumably the argv[0] slot).
                stdout = io.StringIO()
                torch.utils.model_dump.main(
                    [
                        None,
                        f"--style={style}",
                        tf.name,
                    ],
                    stdout=stdout)
                return stdout.getvalue()

            self.assertRegex(run_main("json"), r'\A{.*SimpleModel')

            # Newlines are flattened so ``.`` in the pattern can span lines.
            self.assertRegex(
                run_main("html").replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

    def get_quant_model(self):
        """Build a quantized QuantModel: fuse, prepare, calibrate, convert."""
        fmodel = QuantModel().eval()
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        # Run one random batch through the prepared model before converting.
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        """Dump a scripted quantized model."""
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @skipIfNoXNNPACK
    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        """Dump a traced quantized model after ``optimize_for_mobile``."""
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        """Dump a scripted module that has list attributes."""
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        """Dumping must still succeed when an extra ``.json`` file is malformed."""
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

    @webdriver_test
    def test_memory_computation(self, wd):
        """Check the byte counts shown in the page's "Tensor Memory" section."""
        # See open_section_and_get_body for why ``By`` is used here.
        from selenium.webdriver.common.by import By

        def check_memory(model, expected):
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            # NOTE(review): an XPath that starts with "//" is evaluated from
            # the document root, not relative to ``memory_table``; confirm
            # this is intended (".//" would scope it to the section).
            device = memory_table.find_element(By.XPATH, "//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element(By.XPATH, "//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # 32-bit float
        ) * 4

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants.  Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

        # w1 has 2 elements and w2 has 4, at 4 bytes each.
        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            4 * (2 + 4))
|
|
|
|
|
|
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|