mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
move write_vis into contrib
This commit is contained in:
committed by
Soumith Chintala
parent
a194e66186
commit
5949bb27b5
0
torch/contrib/__init__.py
Normal file
0
torch/contrib/__init__.py
Normal file
112
torch/contrib/_graph_vis.py
Normal file
112
torch/contrib/_graph_vis.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""
|
||||
Experimental. Tools for visualizing the torch.jit.Graph objects.
|
||||
"""
|
||||
import string
|
||||
import json
|
||||
|
||||
_vis_template = string.Template("""
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>$name</title>
|
||||
|
||||
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/vis/4.20.1/vis.min.js"></script>
|
||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/vis/4.20.1/vis.min.css" rel="stylesheet" type="text/css" />
|
||||
|
||||
<style type="text/css">
|
||||
#mynetwork {
|
||||
height: 100vh;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<div id="mynetwork"></div>
|
||||
|
||||
<script type="text/javascript">
|
||||
// create an array with nodes
|
||||
var nodes = new vis.DataSet(
|
||||
$nodes
|
||||
);
|
||||
|
||||
// create an array with edges
|
||||
var edges = new vis.DataSet(
|
||||
$edges
|
||||
);
|
||||
|
||||
// create a network
|
||||
var container = document.getElementById('mynetwork');
|
||||
var data = {
|
||||
nodes: nodes,
|
||||
edges: edges
|
||||
};
|
||||
var options = $options;
|
||||
var network = new vis.Network(container, data, options);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
|
||||
|
||||
def write(self, filename):
|
||||
"""
|
||||
Write an html file that visualizes a torch.jit.Graph using vis.js
|
||||
Arguments:
|
||||
self (torch.jit.Graph): the graph.
|
||||
filename (string): the output filename, an html-file.
|
||||
"""
|
||||
|
||||
nodes = []
|
||||
edges = []
|
||||
options = {}
|
||||
for n, i in enumerate(self.inputs()):
|
||||
nodes.append({
|
||||
'id': i.unique(),
|
||||
'label': 'input {}'.format(n),
|
||||
'shape': 'square',
|
||||
})
|
||||
|
||||
existing = set()
|
||||
|
||||
def add_edge(i_, n):
|
||||
i = i_ if i_.kind() != 'Select' else i_.input()
|
||||
if (i, n) in existing:
|
||||
return
|
||||
existing.add((i, n))
|
||||
e = {
|
||||
'from': n.unique(),
|
||||
'to': i.unique(),
|
||||
'arrows': 'from',
|
||||
}
|
||||
if i.stage() != n.stage():
|
||||
e['color'] = 'green'
|
||||
edges.append(e)
|
||||
|
||||
counts = {}
|
||||
offset = 0
|
||||
for n in self.nodes():
|
||||
if len(n.uses()) == 0 or n.kind() == 'Select' or n.kind() == 'Undefined':
|
||||
continue
|
||||
ident = counts.get(n.kind(), 0)
|
||||
counts[n.kind()] = ident + 1
|
||||
d = {
|
||||
'id': n.unique(),
|
||||
'label': '{}_{}'.format(n.kind(), ident),
|
||||
'y': offset,
|
||||
'fixed': {'y': True},
|
||||
}
|
||||
if n in self.outputs():
|
||||
d['shape'] = 'triangle'
|
||||
|
||||
for i in n.inputs():
|
||||
add_edge(i, n)
|
||||
|
||||
nodes.append(d)
|
||||
offset += 30
|
||||
|
||||
result = _vis_template.substitute(nodes=json.dumps(nodes),
|
||||
edges=json.dumps(edges),
|
||||
options=json.dumps(options),
|
||||
name=filename)
|
||||
with open(filename, 'w') as f:
|
||||
f.write(result)
|
@ -7,6 +7,7 @@ import itertools
|
||||
import types
|
||||
import contextlib
|
||||
import os
|
||||
import torch.contrib._graph_vis as graph_vis
|
||||
# Example how to use:
|
||||
#
|
||||
# import torch.jit
|
||||
@ -97,7 +98,7 @@ def _dump_trace(trace_name, name, suffix, complete_trace):
|
||||
filename = "{}_{}_{}".format(trace_name, name, suffix)
|
||||
with open(filename + ".ir", "w") as f:
|
||||
f.write(str(complete_trace))
|
||||
complete_trace.graph().write_vis(filename + ".html")
|
||||
graph_vis.write(complete_trace.graph(), filename + ".html")
|
||||
|
||||
|
||||
# holds run() to run the function and self.inputs which
|
||||
|
103
torch/onnx.py
103
torch/onnx.py
@ -93,106 +93,3 @@ def _op(self, opname, *args, **kwargs):
|
||||
|
||||
|
||||
torch._C.Graph.op = _op
|
||||
|
||||
|
||||
_vis_template = string.Template("""
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Network | Basic usage</title>
|
||||
|
||||
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/vis/4.20.1/vis.min.js"></script>
|
||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/vis/4.20.1/vis.min.css" rel="stylesheet" type="text/css" />
|
||||
|
||||
<style type="text/css">
|
||||
#mynetwork {
|
||||
height: 100vh;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<div id="mynetwork"></div>
|
||||
|
||||
<script type="text/javascript">
|
||||
// create an array with nodes
|
||||
var nodes = new vis.DataSet(
|
||||
$nodes
|
||||
);
|
||||
|
||||
// create an array with edges
|
||||
var edges = new vis.DataSet(
|
||||
$edges
|
||||
);
|
||||
|
||||
// create a network
|
||||
var container = document.getElementById('mynetwork');
|
||||
var data = {
|
||||
nodes: nodes,
|
||||
edges: edges
|
||||
};
|
||||
var options = $options;
|
||||
var network = new vis.Network(container, data, options);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
|
||||
|
||||
def _write_vis(self, filename):
|
||||
nodes = []
|
||||
edges = []
|
||||
options = {}
|
||||
for n, i in enumerate(self.inputs()):
|
||||
nodes.append({
|
||||
'id': i.unique(),
|
||||
'label': 'input {}'.format(n),
|
||||
'shape': 'square',
|
||||
})
|
||||
|
||||
existing = set()
|
||||
|
||||
def add_edge(i_, n):
|
||||
i = i_ if i_.kind() != 'Select' else i_.input()
|
||||
if (i, n) in existing:
|
||||
return
|
||||
existing.add((i, n))
|
||||
e = {
|
||||
'from': n.unique(),
|
||||
'to': i.unique(),
|
||||
'arrows': 'from',
|
||||
}
|
||||
if i.stage() != n.stage():
|
||||
e['color'] = 'green'
|
||||
edges.append(e)
|
||||
|
||||
counts = {}
|
||||
offset = 0
|
||||
for n in self.nodes():
|
||||
if len(n.uses()) == 0 or n.kind() == 'Select' or n.kind() == 'Undefined':
|
||||
continue
|
||||
ident = counts.get(n.kind(), 0)
|
||||
counts[n.kind()] = ident + 1
|
||||
d = {
|
||||
'id': n.unique(),
|
||||
'label': '{}_{}'.format(n.kind(), ident),
|
||||
'y': offset,
|
||||
'fixed': {'y': True},
|
||||
}
|
||||
if n in self.outputs():
|
||||
d['shape'] = 'triangle'
|
||||
|
||||
for i in n.inputs():
|
||||
add_edge(i, n)
|
||||
|
||||
nodes.append(d)
|
||||
offset += 30
|
||||
|
||||
result = _vis_template.substitute(nodes=json.dumps(nodes),
|
||||
edges=json.dumps(edges),
|
||||
options=json.dumps(options))
|
||||
with open(filename, 'w') as f:
|
||||
f.write(result)
|
||||
|
||||
|
||||
torch._C.Graph.write_vis = _write_vis
|
||||
|
Reference in New Issue
Block a user