[BE] Remove Model Dump utility (#141540)

I found this utility by accident while trying to count how many HTML files we have in the repo so I could convert them to Markdown.

It turns out we package some HTML and JS files in PyTorch to visualize TorchScript models. This seems strange and probably shouldn't be in core, so I removed the tests I could find. Some internal tests may break, but considering TorchScript is being superseded, this seems like a reasonable time to do it.

The last meaningful update to the test for this file was about 2 years ago by @digantdesai; since then it's been a bunch of routine upgrades.

This package appears to be unused: https://github.com/search?type=code&auto_enroll=true&q=torch.utils.model_dump&p=1. I skimmed through 5 pages of results, and the only times it shows up in code search are when someone is cloning pytorch or checking their venv into GitHub.
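For context, here is roughly how the utility was used, as a minimal sketch based on the removed test below (TinyModel and model.pt are illustrative, not from this PR):

    import io
    import torch
    import torch.utils.model_dump

    class TinyModel(torch.nn.Module):  # hypothetical stand-in for a real model
        def forward(self, x):
            return x + 1

    model = torch.jit.script(TinyModel())
    torch.jit.save(model, "model.pt")

    out = io.StringIO()
    torch.utils.model_dump.main([None, "--style=html", "model.pt"], stdout=out)
    html = out.getvalue()  # a self-contained HTML page describing the model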
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141540
Approved by: https://github.com/malfet
Mark Saroufim
2024-11-27 22:52:53 +00:00
committed by PyTorch MergeBot
parent 533798ef46
commit e24190709f
5 changed files with 2 additions and 246 deletions


@@ -156,7 +156,6 @@ coverage_ignore_functions = [
    "DistributedDataParallelCPU",
    # torch.utils
    "set_module",
    # torch.utils.model_dump
    "burn_in_info",
    "get_info_and_burn_skeleton",
    "get_inline_skeleton",


@@ -433,7 +433,6 @@ S390X_TESTLIST = [
    "test_mkldnn_verbose",
    "test_mkl_verbose",
    "test_mobile_optimizer",
    "test_model_dump",
    "test_model_exports_to_core_aten",
    "test_module_tracker",
    "test_monitor",


@@ -1,242 +0,0 @@
#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import os
import io
import functools
import tempfile
import urllib
import unittest

import torch
import torch.backends.xnnpack
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
from torch.testing._internal.common_quantized import supported_qengines


class SimpleModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        act = features
        act = self.layer1(act)
        act = self.relu1(act)
        act = self.layer2(act)
        act = self.relu2(act)
        return act


class QuantModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        x = self.quant(x)
        x = self.core(x)
        x = self.dequant(x)
        return x


class ModelWithLists(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        arg = arg + self.rt[0]
        o = self.ot[0]
        if o is not None:
            arg = arg + o
        return arg


def webdriver_test(testfunc):
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()
        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        from selenium import webdriver
        for driver in [
            "Firefox",
            "Chrome",
        ]:
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                testfunc(self, wd, *args, **kwds)
                wd.close()
    return wrapper


class TestModelDump(TestCase):
    def needs_resources(self):
        pass

    def test_inline_skeleton(self):
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        assert "unpkg.org" not in skel
        assert "src=" not in skel

    def do_dump_model(self, model, extra_files=None):
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        assert info is not None

    def open_html_model(self, wd, model, extra_files=None):
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        container = wd.find_element_by_xpath(f"//div[@data-hider-title='{name}']")
        caret = container.find_element_by_class_name("caret")
        if container.get_attribute("data-shown") != "true":
            caret.click()
        content = container.find_element_by_tag_name("div")
        return content

    def test_scripted_model(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI. Just skip it.
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below
            tf.flush()

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

    def get_quant_model(self):
        fmodel = QuantModel().eval()
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @skipIfNoXNNPACK
    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

    @webdriver_test
    def test_memory_computation(self, wd):
        def check_memory(model, expected):
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            device = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # 32-bit float
        ) * 4

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants. Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            4 * (2 + 4))


if __name__ == '__main__':
    run_tests()
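For reference, the simple_model_memory value asserted in test_memory_computation above works out to 6432 bytes; the arithmetic, spelled out:

    layer1 = 64 * (16 + 1)   # Linear(16, 64): weights + bias = 1088 params
    layer2 = 8 * (64 + 1)    # Linear(64, 8):  weights + bias = 520 params
    total_bytes = (layer1 + layer2) * 4   # 1608 float32 params * 4 bytes
    assert total_bytes == 6432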


@@ -278,8 +278,6 @@ class TestPublicBindings(TestCase):
        for mod in pkgutil.walk_packages(torch.__path__, "torch.", onerror=onerror):
            modname = mod.name
            try:
                # TODO: fix "torch/utils/model_dump/__main__.py"
                # which calls sys.exit() when we try to import it
                if "__main__" in modname:
                    continue
                importlib.import_module(modname)


@@ -76,6 +76,7 @@ import urllib.parse
import zipfile
from pathlib import Path
from typing import Dict
import warnings
import torch.utils.show_pickle
@@ -389,6 +390,7 @@ def get_info_and_burn_skeleton(path_or_bytesio, **kwargs):
def main(argv, *, stdout=None):
    warnings.warn("torch.utils.model_dump is deprecated and will be removed in a future PyTorch release.")
    parser = argparse.ArgumentParser()
    parser.add_argument("--style", choices=["json", "html"])
    parser.add_argument("--title")