mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove duplicated _optimize_trace and use core (#20394)
Summary: The duplicated code of `_optimize_trace` in _pytorch_graph.py is used to bypass some optimization step which causes missing scope. It seems that most of the problematic steps have been fixed recently. Standard models implemented in torchvision are visually inspected before the commit. However, the `+=` in50d54a82d1/torchvision/models/resnet.py (L63)
will letf4d9bfaa4d/torch/onnx/utils.py (L159)
produce a bad result. It can be fixed by replacing it with `out += identity`. This also implies that `+=` has non-intuitive behavior. cc orionr ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/20394 Reviewed By: NarineK Differential Revision: D15452204 Pulled By: orionr fbshipit-source-id: eaa4c13f16551c78dc6419f1e22eb2c560af4cc5
This commit is contained in:
committed by
Facebook Github Bot
parent
871c9dcb1d
commit
5952ca8d9f
@ -36,7 +36,7 @@ void RemoveInplaceOps(Block* block) {
|
||||
// create a replacement out of place op
|
||||
auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
|
||||
newNode->insertBefore(node);
|
||||
|
||||
newNode->setScope(node->scope());
|
||||
// copy inputs
|
||||
for (auto input : node->inputs()) {
|
||||
newNode->addInput(input);
|
||||
|
@ -11,7 +11,7 @@ from tensorboard.compat.proto.versions_pb2 import VersionDef
|
||||
import torch
|
||||
from ._proto_graph import node_proto
|
||||
from torch.onnx.utils import OperatorExportTypes
|
||||
|
||||
from torch.onnx import _optimize_trace
|
||||
|
||||
methods_OP = ['attributeNames', 'hasMultipleOutputs', 'hasUses', 'inputs',
|
||||
'kind', 'outputs', 'outputsSize', 'scopeName']
|
||||
@ -247,60 +247,6 @@ def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_
|
||||
"""
|
||||
operator_export_type = getattr(OperatorExportTypes, operator_export_type)
|
||||
|
||||
# This code is similar to torch/onnx/utils.py, but adjusted to provide
|
||||
# the most visually understandable output.
|
||||
#
|
||||
# For example, the commented out line
|
||||
#
|
||||
# # torch._C._jit_pass_onnx_peephole(graph).
|
||||
#
|
||||
# This pass removes a lot of scope information. The amount of optimization
|
||||
# cannot be too much (lots of information lost) or too little (too much
|
||||
# useless information), therefore I copy-pasted the code so that it will
|
||||
# not be affected by torch/onnx/utils.py changes.
|
||||
def _optimize_trace(trace, operator_export_type):
|
||||
trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))
|
||||
|
||||
def _optimize_graph(graph, operator_export_type):
|
||||
# torch._C._jit_pass_remove_inplace_ops(graph)
|
||||
# we record now record some ops like ones/zeros
|
||||
# into a trace where we previously recorded constants
|
||||
# use constant prop to maintain our current level of onnx support
|
||||
# without implementing symbolics for all of them
|
||||
torch._C._jit_pass_constant_propagation(graph)
|
||||
torch.onnx.utils._split_tensor_list_constants(graph, graph)
|
||||
# run dce to eliminate dead parts of the graph that might have been
|
||||
# left behind by things like symbolic_override
|
||||
torch._C._jit_pass_dce(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
|
||||
# torch._C._jit_pass_canonicalize_ops(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
|
||||
torch._C._jit_pass_peephole(graph, True)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
|
||||
# onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
|
||||
torch._C._jit_pass_prepare_division_for_onnx(graph)
|
||||
# onnx only supports tensors, so we turn all out number types into tensors
|
||||
torch._C._jit_pass_erase_number_types(graph)
|
||||
# onnx does not support tuples, so try to remove them
|
||||
torch._C._jit_pass_lower_all_tuples(graph)
|
||||
torch._C._jit_pass_peephole(graph, True)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
|
||||
if operator_export_type != OperatorExportTypes.RAW:
|
||||
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
# torch._C._jit_pass_onnx_peephole(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
torch._C._jit_pass_dce(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
torch._C._jit_pass_fixup_onnx_loops(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
graph = torch._C._jit_pass_canonicalize(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
return graph
|
||||
|
||||
with torch.onnx.set_training(model, False):
|
||||
try:
|
||||
@ -314,7 +260,7 @@ def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_
|
||||
torch.onnx.export(
|
||||
model, args, tempfile.TemporaryFile(), verbose=True)
|
||||
except RuntimeError:
|
||||
print("Your model fails onnx too, please report to onnx team")
|
||||
print("Your model cannot be exported by onnx, please report to onnx team")
|
||||
# Create an object matching
|
||||
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
|
||||
# The producer version has been reverse engineered from standard
|
||||
|
Reference in New Issue
Block a user