Ignore shape inference exception from Caffe2 ATen fallback (#90408)

Fixes #87318

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90408
Approved by: https://github.com/BowenBao
This commit is contained in:
Thiago Crepaldi
2023-03-08 20:02:11 +00:00
committed by PyTorch MergeBot
parent c988de1040
commit b9c25f186c
2 changed files with 25 additions and 11 deletions

View File

@ -5,7 +5,6 @@ import os
import shutil import shutil
import sys import sys
import tempfile import tempfile
import unittest
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -87,10 +86,6 @@ class TestExportModes(pytorch_test_common.ExportTestCase):
x = torch.ones(3) x = torch.ones(3)
torch.onnx.export(foo, (x,), f) torch.onnx.export(foo, (x,), f)
# TODO(87318): Can't pass even with Caffe2
@unittest.skip(
"RuntimeError: ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type"
)
@common_utils.skipIfNoCaffe2 @common_utils.skipIfNoCaffe2
@common_utils.skipIfNoLapack @common_utils.skipIfNoLapack
def test_caffe2_aten_fallback(self): def test_caffe2_aten_fallback(self):

View File

@ -686,9 +686,19 @@ def _optimize_graph(
graph = _C._jit_pass_canonicalize(graph) graph = _C._jit_pass_canonicalize(graph)
_C._jit_pass_lint(graph) _C._jit_pass_lint(graph)
if GLOBALS.onnx_shape_inference: if GLOBALS.onnx_shape_inference:
_C._jit_pass_onnx_graph_shape_type_inference( try:
graph, params_dict, GLOBALS.export_onnx_opset_version _C._jit_pass_onnx_graph_shape_type_inference(
) graph, params_dict, GLOBALS.export_onnx_opset_version
)
except RuntimeError as exc:
if (
_C_onnx._CAFFE2_ATEN_FALLBACK
and exc.args[0]
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
):
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
pass
return graph return graph
@ -1183,9 +1193,18 @@ def _model_to_graph(
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
if GLOBALS.onnx_shape_inference: if GLOBALS.onnx_shape_inference:
_C._jit_pass_onnx_graph_shape_type_inference( try:
graph, params_dict, GLOBALS.export_onnx_opset_version _C._jit_pass_onnx_graph_shape_type_inference(
) graph, params_dict, GLOBALS.export_onnx_opset_version
)
except RuntimeError as exc:
if (
_C_onnx._CAFFE2_ATEN_FALLBACK
and exc.args[0]
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
):
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
pass
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)