mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56868 See __init__.py for a summary of the tool. The following sections are present in this initial version - Model Size. Show the total model size, as well as a breakdown by stored files, compressed files, and zip overhead. (I expect this breakdown to be a bit more useful once data.pkl is compressed.) - Model Structure. This is basically the output of `show_pickle(data.pkl)`, but as a hierarchical structure. Some structures cause this view to crash right now, but it can be improved incrementally. - Zip Contents. This is basically the output of `zipinfo -l`. - Code. This is the TorchScript code. It's integrated with a blame window at the bottom, so you can click "Blame Code", then click a bit of code to see where it came from (based on the debug_pkl). This currently doesn't render properly if debug_pkl is missing or incomplete. - Extra files (JSON). JSON dumps of each json file under /extra/, up to a size limit. - Extra Pickles. For each .pkl file in the model, we safely unpickle it with `show_pickle`, then render it with `pprint` and include it here if the size is not too large. We aren't able to install the pprint hack that the show_pickle CLI uses, so we get one-line rendering for custom objects, which is not very useful. Built-in types look fine, though. In particular, bytecode.pkl seems to look fine (and we hard-code that file to ignore the size limit). I'm checking in the JS dependencies to avoid a network dependency at runtime. They were retrieved from the following URLs, then passed through a JS minifier: https://unpkg.com/htm@3.0.4/dist/htm.module.js?module https://unpkg.com/preact@10.5.13/dist/preact.module.js?module Test Plan: Manually ran on a few models I had lying around. Mostly tested in Chrome, but I also poked around in Firefox. Reviewed By: dhruvbird Differential Revision: D28020849 Pulled By: dreiss fbshipit-source-id: 421c30ed7ca55244e9fda1a03b8aab830466536d
115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
#!/usr/bin/env python3
|
|
import sys
|
|
import io
|
|
import unittest
|
|
|
|
import torch
|
|
import torch.utils.model_dump
|
|
import torch.utils.mobile_optimizer
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
from torch.testing._internal.common_quantized import supported_qengines
|
|
|
|
|
|
class SimpleModel(torch.nn.Module):
    """Small two-layer MLP fixture (16 -> 64 -> 8, ReLU after each layer).

    Used as the basic scriptable/traceable model for the model_dump tests.
    """

    def __init__(self):
        super().__init__()
        # Two Linear layers, each paired with its own ReLU module so the
        # quantization tests can fuse Linear+ReLU by name.
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        """Apply layer1 -> relu1 -> layer2 -> relu2 to `features`."""
        hidden = self.relu1(self.layer1(features))
        return self.relu2(self.layer2(hidden))
|
|
class QuantModel(torch.nn.Module):
    """SimpleModel wrapped with quant/dequant stubs.

    The stubs mark the boundaries for eager-mode post-training
    quantization (see TestModelDump.get_quant_model).
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        # Quantize on entry, run the float core, dequantize on exit.
        return self.dequant(self.core(self.quant(x)))
|
class ModelWithLists(torch.nn.Module):
    """Module whose state includes list attributes.

    `rt` is inferred by TorchScript as List[Tensor] and `ot` as
    List[Optional[Tensor]]; this exercises model_dump's handling of
    list-typed module attributes.
    """

    def __init__(self):
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        """Add the required tensor and, if present, the optional one."""
        result = arg + self.rt[0]
        maybe = self.ot[0]
        if maybe is not None:
            result = result + maybe
        return result
|
class TestModelDump(TestCase):
|
|
@unittest.skipIf(sys.version_info < (3, 7), "importlib.resources was new in 3.7")
|
|
def test_inline_skeleton(self):
|
|
skel = torch.utils.model_dump.get_inline_skeleton()
|
|
assert "unpkg.org" not in skel
|
|
assert "src=" not in skel
|
|
|
|
def do_dump_model(self, model, extra_files=None):
|
|
# Just check that we're able to run successfully.
|
|
buf = io.BytesIO()
|
|
torch.jit.save(model, buf, _extra_files=extra_files)
|
|
info = torch.utils.model_dump.get_model_info(buf)
|
|
assert info is not None
|
|
|
|
def test_scripted_model(self):
|
|
model = torch.jit.script(SimpleModel())
|
|
self.do_dump_model(model)
|
|
|
|
def test_traced_model(self):
|
|
model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
|
|
self.do_dump_model(model)
|
|
|
|
def get_quant_model(self):
|
|
fmodel = QuantModel().eval()
|
|
fmodel = torch.quantization.fuse_modules(fmodel, [
|
|
["core.layer1", "core.relu1"],
|
|
["core.layer2", "core.relu2"],
|
|
])
|
|
fmodel.qconfig = torch.quantization.get_default_qconfig("qnnpack")
|
|
prepped = torch.quantization.prepare(fmodel)
|
|
prepped(torch.randn(2, 16))
|
|
qmodel = torch.quantization.convert(prepped)
|
|
return qmodel
|
|
|
|
@unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
|
|
def test_quantized_model(self):
|
|
qmodel = self.get_quant_model()
|
|
self.do_dump_model(torch.jit.script(qmodel))
|
|
|
|
@unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
|
|
def test_optimized_quantized_model(self):
|
|
qmodel = self.get_quant_model()
|
|
smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
|
|
omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
|
|
self.do_dump_model(omodel)
|
|
|
|
def test_model_with_lists(self):
|
|
model = torch.jit.script(ModelWithLists())
|
|
self.do_dump_model(model)
|
|
|
|
def test_invalid_json(self):
|
|
model = torch.jit.script(SimpleModel())
|
|
self.do_dump_model(model, extra_files={"foo.json": "{"})
|
|
|
|
|
|
if __name__ == '__main__':
    # Dispatch to the shared PyTorch test runner when executed directly.
    run_tests()
|