Remove duplicated _optimize_trace and use core (#20394)

Summary:
The duplicated `_optimize_trace` code in _pytorch_graph.py was used to bypass some optimization steps that cause missing scope information.

It seems that most of the problematic steps have been fixed recently. Standard models implemented in torchvision were visually inspected before this commit. However, the `+=` in 50d54a82d1/torchvision/models/resnet.py (L63) will cause f4d9bfaa4d/torch/onnx/utils.py (L159) to produce a bad result. It can be fixed by replacing it with `out = out + identity`. This also implies that `+=` has non-intuitive behavior here.

cc orionr ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20394

Reviewed By: NarineK

Differential Revision: D15452204

Pulled By: orionr

fbshipit-source-id: eaa4c13f16551c78dc6419f1e22eb2c560af4cc5
This commit is contained in:
Tzu-Wei Huang
2019-05-22 18:27:24 -07:00
committed by Facebook Github Bot
parent 871c9dcb1d
commit 5952ca8d9f
2 changed files with 3 additions and 57 deletions

View File

@ -36,7 +36,7 @@ void RemoveInplaceOps(Block* block) {
// create a replacement out of place op
auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
newNode->insertBefore(node);
newNode->setScope(node->scope());
// copy inputs
for (auto input : node->inputs()) {
newNode->addInput(input);

View File

@ -11,7 +11,7 @@ from tensorboard.compat.proto.versions_pb2 import VersionDef
import torch
from ._proto_graph import node_proto
from torch.onnx.utils import OperatorExportTypes
from torch.onnx import _optimize_trace
methods_OP = ['attributeNames', 'hasMultipleOutputs', 'hasUses', 'inputs',
'kind', 'outputs', 'outputsSize', 'scopeName']
@ -247,60 +247,6 @@ def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_
"""
operator_export_type = getattr(OperatorExportTypes, operator_export_type)
# This code is similar to torch/onnx/utils.py, but adjusted to provide
# the most visually understandable output.
#
# For example, the commented out line
#
# # torch._C._jit_pass_onnx_peephole(graph).
#
# This pass removes a lot of scope information. The amount of optimization
# cannot be too much (lots of information lost) or too little (too much
# useless information), therefore I copy-pasted the code so that it will
# not be affected by torch/onnx/utils.py changes.
def _optimize_trace(trace, operator_export_type):
trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))
def _optimize_graph(graph, operator_export_type):
# torch._C._jit_pass_remove_inplace_ops(graph)
# we now record some ops like ones/zeros
# into a trace where we previously recorded constants
# use constant prop to maintain our current level of onnx support
# without implementing symbolics for all of them
torch._C._jit_pass_constant_propagation(graph)
torch.onnx.utils._split_tensor_list_constants(graph, graph)
# run dce to eliminate dead parts of the graph that might have been
# left behind by things like symbolic_override
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
# torch._C._jit_pass_canonicalize_ops(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lint(graph)
# onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
torch._C._jit_pass_prepare_division_for_onnx(graph)
# onnx only supports tensors, so we turn all out number types into tensors
torch._C._jit_pass_erase_number_types(graph)
# onnx does not support tuples, so try to remove them
torch._C._jit_pass_lower_all_tuples(graph)
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lint(graph)
if operator_export_type != OperatorExportTypes.RAW:
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
torch._C._jit_pass_lint(graph)
# torch._C._jit_pass_onnx_peephole(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_fixup_onnx_loops(graph)
torch._C._jit_pass_lint(graph)
graph = torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
return graph
with torch.onnx.set_training(model, False):
try:
@ -314,7 +260,7 @@ def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_
torch.onnx.export(
model, args, tempfile.TemporaryFile(), verbose=True)
except RuntimeError:
print("Your model fails onnx too, please report to onnx team")
print("Your model cannot be exported by onnx, please report to onnx team")
# Create an object matching
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
# The producer version has been reverse engineered from standard