Ignore shape inference exception from Caffe2 ATen fallback (#90408)

Fixes #87318

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90408
Approved by: https://github.com/BowenBao
This commit is contained in:
Thiago Crepaldi
2023-03-08 20:02:11 +00:00
committed by PyTorch MergeBot
parent c988de1040
commit b9c25f186c
2 changed files with 25 additions and 11 deletions

View File

@ -686,9 +686,19 @@ def _optimize_graph(
graph = _C._jit_pass_canonicalize(graph)
_C._jit_pass_lint(graph)
if GLOBALS.onnx_shape_inference:
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
try:
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
except RuntimeError as exc:
if (
_C_onnx._CAFFE2_ATEN_FALLBACK
and exc.args[0]
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
):
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
pass
return graph
@ -1183,9 +1193,18 @@ def _model_to_graph(
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
if GLOBALS.onnx_shape_inference:
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
try:
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
except RuntimeError as exc:
if (
_C_onnx._CAFFE2_ATEN_FALLBACK
and exc.args[0]
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
):
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
pass
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)