mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Summary: The `2to3` tool has a `future` fixer that can be targeted specifically to remove these redundant `from __future__` imports; the `caffe2` directory has the most of them: ```2to3 -f future -w caffe2``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/45033 Reviewed By: seemethere Differential Revision: D23808648 Pulled By: bugra fbshipit-source-id: 38971900f0fe43ab44a9168e57f2307580d36a38
413 lines
14 KiB
Python
413 lines
14 KiB
Python
## @package net_drawer
|
|
# Module caffe2.python.net_drawer
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
from collections import defaultdict
|
|
from caffe2.python import utils
|
|
from future.utils import viewitems
|
|
|
|
# Module logger. Its level is set to INFO explicitly so the INFO-level notice
# below is not filtered out by the logger's own level check.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# pydot is an optional dependency: importing this module must succeed without
# it. On ImportError the name `pydot` is bound to None, so code that uses
# `pydot.<attr>` fails only when it actually runs, not at import time.
try:
    import pydot
except ImportError:
    # NOTE(review): the missing dependency is reported twice, via logging and
    # via print(). Presumably print() exists because INFO records are not
    # displayed without a configured logging handler; confirm before
    # removing either.
    logger.info(
        'Cannot import pydot, which is required for drawing a network. This '
        'can usually be installed in python with "pip install pydot". Also, '
        'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
        'can usually be installed with "sudo apt-get install graphviz".'
    )
    print(
        'net_drawer will not run correctly. Please install the correct '
        'dependencies.'
    )
    pydot = None
|
|
|
|
from caffe2.proto import caffe2_pb2
|
|
|
|
# Default pydot node attributes for operator (and plan-step) nodes:
# filled boxes, green (#0F9D58) fill with white (#FFFFFF) text.
OP_STYLE = dict(
    shape='box',
    color='#0F9D58',
    style='filled',
    fontcolor='#FFFFFF',
)

# Default pydot node attributes for blob nodes.
BLOB_STYLE = dict(shape='octagon')
|
|
|
|
|
|
def _rectify_operator_and_name(operators_or_net, name):
    """Gets the operators and name for the pydot graph.

    Accepts a NetDef, an object whose ``Proto()`` method returns a NetDef,
    or a plain iterable of operators. An explicit ``name`` always wins;
    otherwise the net's own name is used, or ``"unnamed"`` for a bare
    operator list.

    Returns:
        (operators, name) tuple.

    Raises:
        RuntimeError: if ``Proto()`` returns something other than a NetDef.
    """
    if isinstance(operators_or_net, caffe2_pb2.NetDef):
        net = operators_or_net
    elif hasattr(operators_or_net, 'Proto'):
        net = operators_or_net.Proto()
        if not isinstance(net, caffe2_pb2.NetDef):
            raise RuntimeError(
                "Expecting NetDef, but got {}".format(type(net)))
    else:
        # A bare operator list carries no name of its own.
        return operators_or_net, ("unnamed" if name is None else name)
    return net.op, (net.name if name is None else name)
|
|
|
|
|
|
def _escape_label(name):
|
|
# json.dumps is poor man's escaping
|
|
return json.dumps(name)
|
|
|
|
|
|
def GetOpNodeProducer(append_output, **kwargs):
    """Return a factory ``(op, op_id) -> pydot.Node`` for operator nodes.

    The node name is ``'<op.name>/<op.type> (op#<op_id>)'``, or
    ``'<op.type> (op#<op_id>)'`` when the op has no name. If
    ``append_output`` is true, each output blob name is appended on its own
    line. ``kwargs`` are passed through to ``pydot.Node`` as node attributes.
    """
    def ReallyGetOpNode(op, op_id):
        prefix = '%s/%s' % (op.name, op.type) if op.name else op.type
        lines = ['%s (op#%d)' % (prefix, op_id)]
        if append_output:
            lines.extend(op.output)
        return pydot.Node('\n'.join(lines), **kwargs)
    return ReallyGetOpNode
|
|
|
|
|
|
def GetBlobNodeProducer(**kwargs):
    """Return a factory ``(node_name, label) -> pydot.Node`` for blob nodes.

    ``kwargs`` are forwarded to every ``pydot.Node`` as node attributes.
    """
    node_attrs = kwargs

    def ReallyGetBlobNode(node_name, label):
        return pydot.Node(node_name, label=label, **node_attrs)

    return ReallyGetBlobNode
|
|
|
|
def GetPydotGraph(
    operators_or_net,
    name=None,
    rankdir='LR',
    op_node_producer=None,
    blob_node_producer=None
):
    """Build a pydot graph with a node per operator and per blob version.

    Every op becomes a node with edges from its input blobs and to its output
    blobs. Writing to a blob that already has a node creates a new node for
    it (its dot ID carries an incremented counter), so in-place updates show
    up as distinct versions.

    Args:
        operators_or_net: a NetDef, an object with ``Proto()`` returning a
            NetDef, or an iterable of operators.
        name: graph name; defaults to the net's name (or "unnamed").
        rankdir: graphviz rank direction.
        op_node_producer: ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.
        blob_node_producer: ``(node_name, label=...) -> pydot.Node``;
            defaults to ``GetBlobNodeProducer(**BLOB_STYLE)``.

    Returns:
        pydot.Dot
    """
    if op_node_producer is None:
        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
    if blob_node_producer is None:
        blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)

    # Most recent node for each blob name, and how often it was overwritten.
    latest_blob_node = {}
    blob_version = defaultdict(int)

    def make_blob_node(blob):
        return blob_node_producer(
            _escape_label(blob + str(blob_version[blob])),
            label=_escape_label(blob),
        )

    for op_id, op in enumerate(operators):
        op_node = op_node_producer(op, op_id)
        graph.add_node(op_node)

        for blob in op.input:
            if blob in latest_blob_node:
                blob_node = latest_blob_node[blob]
            else:
                blob_node = make_blob_node(blob)
                latest_blob_node[blob] = blob_node
            graph.add_node(blob_node)
            graph.add_edge(pydot.Edge(blob_node, op_node))

        for blob in op.output:
            if blob in latest_blob_node:
                # Overwriting an existing blob: start a new version.
                blob_version[blob] += 1
            blob_node = make_blob_node(blob)
            latest_blob_node[blob] = blob_node
            graph.add_node(blob_node)
            graph.add_edge(pydot.Edge(op_node, blob_node))

    return graph
|
|
|
|
|
|
def GetPydotGraphMinimal(
    operators_or_net,
    name=None,
    rankdir='LR',
    minimal_dependency=False,
    op_node_producer=None,
):
    """Different from GetPydotGraph, hide all blob nodes and only show op nodes.

    An edge a->b is drawn when op b reads a blob whose most recent writer
    (earlier in the operator list) is op a. If b reads several blobs written
    by the same op a, only a single a->b edge is drawn.

    If minimal_dependency is set as well, for each op, we will only draw the
    edges to the minimal necessary ancestors. For example, if op c depends on
    op a and b, and op b depends on a, then only the edge b->c will be drawn
    because a->c will be implied.

    Args:
        operators_or_net: a NetDef, an object with ``Proto()`` returning a
            NetDef, or an iterable of operators.
        name: graph name; defaults to the net's name (or "unnamed").
        rankdir: graphviz rank direction.
        minimal_dependency: see above.
        op_node_producer: ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.

    Returns:
        pydot.Dot
    """
    if op_node_producer is None:
        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # blob_parents maps each blob name to the op node that last wrote it.
    blob_parents = {}
    # op_ancestry records the (transitive) ancestors of each op node.
    op_ancestry = defaultdict(set)
    for op_id, op in enumerate(operators):
        op_node = op_node_producer(op, op_id)
        graph.add_node(op_node)
        # Direct parents, deduplicated in first-seen order: an op reading
        # several blobs from the same producer must get only one edge from it.
        parents = list(dict.fromkeys(
            blob_parents[input_name] for input_name in op.input
            if input_name in blob_parents
        ))
        ancestry = op_ancestry[op_node]
        ancestry.update(parents)
        for node in parents:
            ancestry.update(op_ancestry[node])
        for node in parents:
            # With minimal_dependency, skip a parent that is already an
            # ancestor of another parent: that edge is implied transitively.
            if minimal_dependency and any(
                    node in op_ancestry[other_node] for other_node in parents):
                continue
            graph.add_edge(pydot.Edge(node, op_node))
        # Update blob_parents to reflect that this op created the blobs.
        for output_name in op.output:
            blob_parents[output_name] = op_node
    return graph
|
|
|
|
|
|
def GetOperatorMapForPlan(plan_def):
    """Map ``'<plan name>_<net name>'`` to each network's operator list.

    Networks without a name are keyed ``'<plan name>_network_<index>'``.
    """
    def key_for(index, net):
        suffix = net.name if net.HasField('name') else "network_%d" % index
        return plan_def.name + "_" + suffix

    return {
        key_for(index, net): net.op
        for index, net in enumerate(plan_def.network)
    }
|
|
|
|
|
|
def _draw_nets(nets, g):
|
|
nodes = []
|
|
for i, net in enumerate(nets):
|
|
nodes.append(pydot.Node(_escape_label(net)))
|
|
g.add_node(nodes[-1])
|
|
if i > 0:
|
|
g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
|
|
return nodes
|
|
|
|
|
|
def _draw_steps(steps, g, skip_step_edges=False): # noqa
|
|
kMaxParallelSteps = 3
|
|
|
|
def get_label():
|
|
label = [step.name + '\n']
|
|
if step.report_net:
|
|
label.append('Reporter: {}'.format(step.report_net))
|
|
if step.should_stop_blob:
|
|
label.append('Stopper: {}'.format(step.should_stop_blob))
|
|
if step.concurrent_substeps:
|
|
label.append('Concurrent')
|
|
if step.only_once:
|
|
label.append('Once')
|
|
return '\n'.join(label)
|
|
|
|
def substep_edge(start, end):
|
|
return pydot.Edge(start, end, arrowhead='dot', style='dashed')
|
|
|
|
nodes = []
|
|
for i, step in enumerate(steps):
|
|
parallel = step.concurrent_substeps
|
|
|
|
nodes.append(pydot.Node(_escape_label(get_label()), **OP_STYLE))
|
|
g.add_node(nodes[-1])
|
|
|
|
if i > 0 and not skip_step_edges:
|
|
g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
|
|
|
|
if step.network:
|
|
sub_nodes = _draw_nets(step.network, g)
|
|
elif step.substep:
|
|
if parallel:
|
|
sub_nodes = _draw_steps(
|
|
step.substep[:kMaxParallelSteps], g, skip_step_edges=True)
|
|
else:
|
|
sub_nodes = _draw_steps(step.substep, g)
|
|
else:
|
|
raise ValueError('invalid step')
|
|
|
|
if parallel:
|
|
for sn in sub_nodes:
|
|
g.add_edge(substep_edge(nodes[-1], sn))
|
|
if len(step.substep) > kMaxParallelSteps:
|
|
ellipsis = pydot.Node('{} more steps'.format(
|
|
len(step.substep) - kMaxParallelSteps), **OP_STYLE)
|
|
g.add_node(ellipsis)
|
|
g.add_edge(substep_edge(nodes[-1], ellipsis))
|
|
else:
|
|
g.add_edge(substep_edge(nodes[-1], sub_nodes[0]))
|
|
|
|
return nodes
|
|
|
|
|
|
def GetPlanGraph(plan_def, name=None, rankdir='TB'):
    """Return a pydot graph of a PlanDef's execution steps (top-to-bottom by default)."""
    plan_graph = pydot.Dot(name, rankdir=rankdir)
    _draw_steps(plan_def.execution_step, plan_graph)
    return plan_graph
|
|
|
|
|
|
def GetGraphInJson(operators_or_net, output_filepath):
    """Write the op/blob dataflow graph of a net to a JSON file.

    The file holds ``{"nodes": [...], "edges": [...]}``.

    Nodes:
        Op node: ``{"id", "label", "op_id", "type": "op"}``.
        Blob node: ``{"id", "label", "type": "blob"}``, one per blob version.
        Each time an op writes a blob that already has a node, a new version
        (and node) is created.
    Node ids are indices into the ``nodes`` list.

    Edges:
        ``{"source": id, "target": id}``, from input blobs to ops and from
        ops to output blobs.

    Args:
        operators_or_net: a NetDef, an object with ``Proto()`` returning a
            NetDef, or an iterable of operators.
        output_filepath: path of the JSON file to (over)write.
    """
    operators, _ = _rectify_operator_and_name(operators_or_net, None)

    # (blob name, version) -> node id. A tuple key is used because joining
    # name and version into one string is ambiguous: blob "a1" version 0 and
    # blob "a" version 10 would both become "a10" and be merged.
    blob_node_ids = {}
    blob_versions = defaultdict(int)
    nodes = []
    edges = []

    def blob_node_id(blob_name):
        """Return the node id for the current version of ``blob_name``,
        creating the node on first use."""
        key = (blob_name, blob_versions[blob_name])
        if key not in blob_node_ids:
            blob_node_ids[key] = len(nodes)
            nodes.append({
                'id': len(nodes),
                'label': blob_name,
                'type': 'blob'
            })
        return blob_node_ids[key]

    for op_id, op in enumerate(operators):
        op_label = op.name + '/' + op.type if op.name else op.type
        op_node_id = len(nodes)
        nodes.append({
            'id': op_node_id,
            'label': op_label,
            'op_id': op_id,
            'type': 'op'
        })
        for input_name in op.input:
            edges.append({
                'source': blob_node_id(input_name),
                'target': op_node_id
            })
        for output_name in op.output:
            if (output_name, blob_versions[output_name]) in blob_node_ids:
                # we are overwriting an existing blob. need to update the count.
                blob_versions[output_name] += 1
            edges.append({
                'source': op_node_id,
                'target': blob_node_id(output_name)
            })

    with open(output_filepath, 'w') as f:
        json.dump({'nodes': nodes, 'edges': edges}, f)
|
|
|
|
|
|
# A dummy minimal PNG image used by GetGraphPngSafe as a
|
|
# placeholder when rendering fail to run.
|
|
_DummyPngImage = (
|
|
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00'
|
|
b'\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx\x9cc`\x00\x00'
|
|
b'\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82')
|
|
|
|
|
|
def GetGraphPngSafe(func, *args, **kwargs):
    """
    Invokes `func` (e.g. GetPydotGraph) with args and renders the resulting
    pydot.Dot to PNG bytes. If anything fails - including `func` returning
    something that is not a pydot.Dot - the error is logged and an empty
    placeholder image is returned instead of throwing an Exception.
    """
    try:
        graph = func(*args, **kwargs)
        if isinstance(graph, pydot.Dot):
            return graph.create_png()
        raise ValueError("func is expected to return pydot.Dot")
    except Exception as e:
        logger.error("Failed to draw graph: {}".format(e))
        return _DummyPngImage
|
|
|
|
|
|
def main():
    """Command-line entry point: draw every net found in a protobuf file.

    The input is parsed as a PlanDef (one graph per network) or a NetDef
    (one graph). For each graph, ``<output_prefix><graph name>.dot`` is
    written, followed by a ``.pdf`` next to it. Converting to pdf needs
    graphviz; if that step fails, a message is printed and only the .dot
    file is kept.
    """
    parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
    parser.add_argument(
        "--input",
        type=str, required=True,
        help="The input protobuf file."
    )
    parser.add_argument(
        "--output_prefix",
        type=str, default="",
        help="The prefix to be added to the output filename."
    )
    parser.add_argument(
        "--minimal", action="store_true",
        help="If set, produce a minimal visualization."
    )
    parser.add_argument(
        "--minimal_dependency", action="store_true",
        help="If set, only draw minimal dependency."
    )
    parser.add_argument(
        "--append_output", action="store_true",
        help="If set, append the output blobs to the operator names.")
    parser.add_argument(
        "--rankdir", type=str, default="LR",
        help="The rank direction of the pydot graph."
    )
    args = parser.parse_args()
    with open(args.input, 'r') as fid:
        content = fid.read()
    graphs = utils.GetContentFromProtoString(
        content, {
            caffe2_pb2.PlanDef: GetOperatorMapForPlan,
            caffe2_pb2.NetDef: lambda x: {x.name: x.op},
        }
    )
    # The producer is a stateless closure, so one instance serves all graphs.
    op_node_producer = GetOpNodeProducer(args.append_output, **OP_STYLE)
    for key, operators in graphs.items():
        # Both builders take `op_node_producer`; the previous `node_producer=`
        # keyword raised TypeError on every invocation.
        if args.minimal:
            graph = GetPydotGraphMinimal(
                operators,
                name=key,
                rankdir=args.rankdir,
                op_node_producer=op_node_producer,
                minimal_dependency=args.minimal_dependency)
        else:
            graph = GetPydotGraph(
                operators,
                name=key,
                rankdir=args.rankdir,
                op_node_producer=op_node_producer)
        filename = args.output_prefix + graph.get_name() + '.dot'
        graph.write(filename, format='raw')
        # Replace the trailing 'dot' extension with 'pdf'.
        pdf_filename = filename[:-3] + 'pdf'
        try:
            graph.write_pdf(pdf_filename)
        except Exception:
            print(
                'Error when writing out the pdf file. Pydot requires graphviz '
                'to convert dot files to pdf, and you may not have installed '
                'graphviz. On ubuntu this can usually be installed with "sudo '
                'apt-get install graphviz". We have generated the .dot file '
                'but will not be able to generate pdf file for now.'
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line net drawer when executed as a script.
    main()
|