mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[package] Add dependency viz (#45214)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45214 When in verbose mode the package exporter will produce an html visualization of dependencies of a module to make it easier to trim out unneeded code, or debug inclusion of things that cannot be exported. Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D23873525 Pulled By: zdevito fbshipit-source-id: 6801991573d8dd5ab8c284e09572b36a35e1e5a4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
331ebaf7cb
commit
e54e1fe51e
@ -6,6 +6,7 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
import torch
|
||||
from sys import version_info
|
||||
from io import StringIO
|
||||
|
||||
try:
|
||||
from torchvision.models import resnet18
|
||||
@ -183,6 +184,11 @@ b = resources.load_binary('main', 'main_binary')
|
||||
# the objects in the pickle
|
||||
e.save_pickle('model', 'model.pkl', resnet)
|
||||
|
||||
# check that the debug graph has something reasonable:
|
||||
buf = StringIO()
|
||||
e._write_dep_graph(failing_module='torch', output_file=buf)
|
||||
self.assertIn('torchvision.models.resnet', buf.getvalue())
|
||||
|
||||
# we can now load the saved model
|
||||
i = PackageImporter(f1)
|
||||
r2 = i.load_pickle('model', 'model.pkl')
|
||||
|
@ -8,11 +8,12 @@ from ._custom_import_pickler import CustomImportPickler
|
||||
from ._importlib import _normalize_path
|
||||
import types
|
||||
import importlib
|
||||
from typing import List, Any, Callable, Dict
|
||||
from typing import List, Any, Callable, Dict, Tuple
|
||||
from distutils.sysconfig import get_python_lib
|
||||
from pathlib import Path
|
||||
import linecache
|
||||
import sys
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
class PackageExporter:
|
||||
""" Exporters allow you to write packages of code, pickled python data, and
|
||||
@ -70,6 +71,7 @@ class PackageExporter:
|
||||
self.provided : Dict[str, bool] = {}
|
||||
self.verbose = verbose
|
||||
self.importers = [importlib.import_module]
|
||||
self.debug_deps : List[Tuple[str, str]] = []
|
||||
|
||||
def save_source_file(self, module_name: str, file_or_directory: str, dependencies=True):
|
||||
"""Adds the local file system `file_or_directory` to the source package to provide the code
|
||||
@ -131,15 +133,9 @@ class PackageExporter:
|
||||
self._write(filename, src)
|
||||
if dependencies:
|
||||
package = module_name if is_package else module_name.rsplit('.', maxsplit=1)[0]
|
||||
dep_list = find_files_source_depends_on(src, package)
|
||||
if self.verbose:
|
||||
def fmt_dep(mod, obj):
|
||||
return f'{mod}' if obj is None else f'{mod}.{obj}'
|
||||
dep_str = ''.join(f' {fmt_dep(mod, obj)}\n' for mod, obj in dep_list)
|
||||
file_info = f'(from file {orig_file_name}) ' if orig_file_name is not None else ''
|
||||
print(f"{module_name} {file_info}depends on:\n{dep_str}\n")
|
||||
|
||||
for dep_module_name, dep_module_obj in dep_list:
|
||||
dep_pairs = find_files_source_depends_on(src, package)
|
||||
dep_list = {}
|
||||
for dep_module_name, dep_module_obj in dep_pairs:
|
||||
# handle the case where someone did something like `from pack import sub`
|
||||
# where `sub` is a submodule. In this case we don't have to save pack, just sub.
|
||||
# this ensures we don't pick up additional dependencies on pack.
|
||||
@ -148,25 +144,117 @@ class PackageExporter:
|
||||
if dep_module_obj is not None:
|
||||
possible_submodule = f'{dep_module_name}.{dep_module_obj}'
|
||||
if self._module_exists(possible_submodule):
|
||||
self.require_module_if_not_provided(possible_submodule)
|
||||
dep_list[possible_submodule] = True
|
||||
# we don't need to save `pack`
|
||||
continue
|
||||
if self._module_exists(dep_module_name):
|
||||
self.require_module_if_not_provided(dep_module_name)
|
||||
dep_list[dep_module_name] = True
|
||||
|
||||
for dep in dep_list.keys():
|
||||
self.debug_deps.append((module_name, dep))
|
||||
|
||||
if self.verbose:
|
||||
dep_str = ''.join(f' {dep}\n' for dep in dep_list.keys())
|
||||
file_info = f'(from file {orig_file_name}) ' if orig_file_name is not None else ''
|
||||
print(f"{module_name} {file_info}depends on:\n{dep_str}\n")
|
||||
|
||||
for dep in dep_list.keys():
|
||||
self.require_module_if_not_provided(dep)
|
||||
|
||||
def _module_exists(self, module_name: str) -> bool:
|
||||
try:
|
||||
self._import_module(module_name)
|
||||
return True
|
||||
except ModuleNotFoundError:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _write_dep_graph(self, failing_module=None, output_file=None):
    """Render the recorded dependency edges as an HTML page using vis.js.

    Every ``(from, to)`` pair in ``self.debug_deps`` becomes a directed edge.
    Nodes are placed in a left-to-right hierarchical layout; a node's level is
    one more than the highest level among the nodes that depend on it.

    Args:
        failing_module: Optional node name. The node with this exact name is
            drawn in red.
        output_file: Optional writable text stream. When given, the HTML is
            written to it instead of to a temporary file.

    Returns:
        ``None`` when ``output_file`` is given; otherwise the path of a newly
        created ``.html`` temporary file (created with ``delete=False``, so the
        caller is responsible for removing it).
    """
    import json  # local import: only this debug helper needs it

    def js_string(text):
        # A JSON string literal is a valid JavaScript string literal, so this
        # safely handles quotes and backslashes in node names. '<' is escaped
        # as well so a name can never terminate the enclosing <script> tag.
        return json.dumps(text).replace('<', '\\u003c')

    # For each node, the list of nodes that depend on it (its "parents").
    depended_on : Dict[str, List[str]] = {}
    for f, t in self.debug_deps:
        depended_on.setdefault(t, [])
        depended_on.setdefault(f, [])
        depended_on[t].append(f)

    level : Dict[str, int] = {}

    def visit(x: str):
        if x in level:
            return level[x]
        # Seed the level before recursing so a dependency cycle terminates
        # instead of recursing forever.
        level[x] = 0
        for parent in depended_on[x]:
            level[x] = max(level[x], visit(parent) + 1)
        return level[x]

    for x in depended_on:
        visit(x)

    # Assign ids in order of first appearance in the edge list.
    nodes = []
    node_to_id = {}
    for ft in self.debug_deps:
        for e in ft:
            if e in node_to_id:
                continue
            n = len(node_to_id)
            node_to_id[e] = n
            extra = ", color: 'red'" if e == failing_module else ''
            nodes.append(f" {{id: {n}, label: {js_string(e)}, level: {level[e]}, shape: 'box'{extra}}},\n")

    edges = [
        f" {{from: {node_to_id[f]}, to: {node_to_id[t]}, arrows: 'to'}},\n"
        for f, t in self.debug_deps
    ]
    nodes_s, edges_s = ''.join(nodes), ''.join(edges)
    template = f"""\
<html>
<head>
<script type="text/javascript" src="https://almende.github.io/vis/dist/vis.js"></script>
<link href="https://almende.github.io/vis/dist/vis.css" rel="stylesheet" type="text/css" />
</head>
<body>
<div id="mynetwork"></div>

<script type="text/javascript">
var nodes = new vis.DataSet([
{nodes_s}
]);
var edges = new vis.DataSet([
{edges_s}
]);
var options = {{
layout: {{
hierarchical: {{
direction: "LR",
levelSeparation: 400,
}},
}},
}};
// create a network
var container = document.getElementById('mynetwork');
var network = new vis.Network(container, {{nodes: nodes, edges: edges}}, options);
</script>
</body>
</html>
"""
    # Compare against None explicitly: a file-like object's truthiness is
    # not a reliable "was an output stream supplied" signal.
    if output_file is not None:
        output_file.write(template)
        return None

    # delete=False so the returned path is still valid after the handle closes.
    with NamedTemporaryFile(mode='w', suffix='.html', delete=False) as tf:
        tf.write(template)
        return tf.name
|
||||
|
||||
def _get_source_of_module(self, module: types.ModuleType) -> str:
|
||||
filename = getattr(module, '__file__', None)
|
||||
result = None if filename is None else linecache.getlines(filename, module.__dict__)
|
||||
result = None if filename is None or not filename.endswith('.py') else linecache.getlines(filename, module.__dict__)
|
||||
if result is None:
|
||||
extra = ''
|
||||
if self.verbose:
|
||||
extra = f' See the dependency graph for more info: {self._write_dep_graph(module.__name__)}'
|
||||
raise ValueError(f'cannot save source for module "{module.__name__}" because '
|
||||
f'its source file "{filename}" could not be found.')
|
||||
f'its source file "{filename}" could not be found.{extra}')
|
||||
return ''.join(result)
|
||||
|
||||
def require_module_if_not_provided(self, module_name: str, dependencies=True):
|
||||
@ -211,6 +299,7 @@ class PackageExporter:
|
||||
return import_module(module_name)
|
||||
except ModuleNotFoundError as err:
|
||||
last_err = err
|
||||
|
||||
if last_err is not None:
|
||||
raise last_err
|
||||
else:
|
||||
@ -258,6 +347,9 @@ class PackageExporter:
|
||||
if module not in all_dependencies:
|
||||
all_dependencies.append(module)
|
||||
|
||||
for dep in all_dependencies:
|
||||
self.debug_deps.append((package + '.' + resource, dep))
|
||||
|
||||
if self.verbose:
|
||||
dep_string = ''.join(f' {dep}\n' for dep in all_dependencies)
|
||||
print(f"{resource} depends on:\n{dep_string}\n")
|
||||
@ -377,6 +469,9 @@ class PackageExporter:
|
||||
with PackageExporter("file.zip") as e:
|
||||
...
|
||||
"""
|
||||
if self.verbose:
|
||||
print(f"Dependency graph for exported package: {self._write_dep_graph()}")
|
||||
|
||||
# Write each tensor to a file named tensor/the_tensor_key in the zip archive
|
||||
for key in sorted(self.serialized_storages.keys()):
|
||||
name = 'data/{}'.format(key)
|
||||
@ -395,6 +490,7 @@ class PackageExporter:
|
||||
self._write('extern_modules', contents)
|
||||
del self.zip_file
|
||||
|
||||
|
||||
def _filename(self, package, resource):
|
||||
package_path = package.replace('.', '/')
|
||||
resource = _normalize_path(resource)
|
||||
@ -412,7 +508,7 @@ _DISALLOWED_MODULES = ['sys', 'io']
|
||||
def _is_builtin_or_stdlib_module(module: types.ModuleType) -> bool:
|
||||
if module.__name__ in sys.builtin_module_names:
|
||||
return True
|
||||
filename = module.__file__
|
||||
filename = getattr(module, '__file__', None)
|
||||
if filename is None:
|
||||
return False
|
||||
standard_lib = get_python_lib(standard_lib=True)
|
||||
|
Reference in New Issue
Block a user