mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables. * list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize. * Manually went back and made mypy happy after the change. * Also fixed style lints in files covered by flake8 but not by pyfmt Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980 Approved by: https://github.com/justinchuby, https://github.com/malfet
62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
# mypy: allow-untyped-defs
|
|
from tensorboard.compat.proto.graph_pb2 import GraphDef
|
|
from tensorboard.compat.proto.node_def_pb2 import NodeDef
|
|
from tensorboard.compat.proto.versions_pb2 import VersionDef
|
|
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
|
|
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
|
|
|
|
|
|
def load_onnx_graph(fname):
    """Load the ONNX model at ``fname`` and return ``parse`` applied to its graph.

    ``onnx`` is imported inside the function, so the package is only
    needed when this function is actually called.
    """
    import onnx

    model = onnx.load(fname)  # type: ignore[attr-defined]
    return parse(model.graph)
|
|
|
|
|
|
def parse(graph):
    """Convert a graph into a TensorBoard ``GraphDef``.

    Args:
        graph: Object exposing ``input``, ``output`` and ``node`` sequences.
            ``load_onnx_graph`` passes ``onnx.load(fname).graph`` here.

    Returns:
        A ``GraphDef`` whose nodes are, in order: one ``"Variable"`` node for
        every entry of ``graph.input`` then ``graph.output``, followed by one
        node for every entry of ``graph.node``. ``versions.producer`` is 22.

    Raises:
        IndexError: If an entry of ``graph.node`` has an empty ``output``
            list, because each operator node is named after ``output[0]``.
    """
    import itertools

    nodes = []

    # Graph inputs and outputs become "Variable" nodes that carry the
    # element type and the shape of the corresponding tensor.
    for value in itertools.chain(graph.input, graph.output):
        # NOTE(review): this writes every name to stdout; kept so the
        # observable output is unchanged -- consider switching to logging.
        print(value.name)
        tensor_type = value.type.tensor_type
        shapeproto = TensorShapeProto(
            dim=[
                TensorShapeProto.Dim(size=d.dim_value)
                for d in tensor_type.shape.dim
            ]
        )
        nodes.append(
            NodeDef(
                name=value.name.encode(encoding="utf_8"),
                op="Variable",
                input=[],
                attr={
                    "dtype": AttrValue(type=tensor_type.elem_type),
                    "shape": AttrValue(shape=shapeproto),
                },
            )
        )

    # Operator nodes: the set fields of each attribute are stringified and
    # joined into one "parameters" byte string; the node takes the name of
    # its first output.
    for node in graph.node:
        _attr = [
            " = ".join([str(f[1]) for f in s.ListFields()])
            for s in node.attribute
        ]
        attr = ", ".join(_attr).encode(encoding="utf_8")
        # NOTE(review): stdout output kept for compatibility (see above).
        print(node.output[0])
        nodes.append(
            NodeDef(
                name=node.output[0].encode(encoding="utf_8"),
                op=node.op_type,
                input=node.input,
                attr={"parameters": AttrValue(s=attr)},
            )
        )

    return GraphDef(node=nodes, versions=VersionDef(producer=22))
|