mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Ignore shape inference exception from Caffe2 ATen fallback (#90408)
Fixes #87318 Pull Request resolved: https://github.com/pytorch/pytorch/pull/90408 Approved by: https://github.com/BowenBao
This commit is contained in:
committed by
PyTorch MergeBot
parent
c988de1040
commit
b9c25f186c
@ -686,9 +686,19 @@ def _optimize_graph(
|
||||
graph = _C._jit_pass_canonicalize(graph)
|
||||
_C._jit_pass_lint(graph)
|
||||
if GLOBALS.onnx_shape_inference:
|
||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||
)
|
||||
try:
|
||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if (
|
||||
_C_onnx._CAFFE2_ATEN_FALLBACK
|
||||
and exc.args[0]
|
||||
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
|
||||
):
|
||||
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
|
||||
pass
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
@ -1183,9 +1193,18 @@ def _model_to_graph(
|
||||
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||
|
||||
if GLOBALS.onnx_shape_inference:
|
||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||
)
|
||||
try:
|
||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if (
|
||||
_C_onnx._CAFFE2_ATEN_FALLBACK
|
||||
and exc.args[0]
|
||||
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
|
||||
):
|
||||
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
|
||||
pass
|
||||
|
||||
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
|
||||
|
||||
|
Reference in New Issue
Block a user