mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
Ignore shape inference exception from Caffe2 ATen fallback (#90408)
Fixes #87318 Pull Request resolved: https://github.com/pytorch/pytorch/pull/90408 Approved by: https://github.com/BowenBao
This commit is contained in:
committed by
PyTorch MergeBot
parent
c988de1040
commit
b9c25f186c
@ -5,7 +5,6 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -87,10 +86,6 @@ class TestExportModes(pytorch_test_common.ExportTestCase):
|
|||||||
x = torch.ones(3)
|
x = torch.ones(3)
|
||||||
torch.onnx.export(foo, (x,), f)
|
torch.onnx.export(foo, (x,), f)
|
||||||
|
|
||||||
# TODO(87318): Can't pass even with Caffe2
|
|
||||||
@unittest.skip(
|
|
||||||
"RuntimeError: ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type"
|
|
||||||
)
|
|
||||||
@common_utils.skipIfNoCaffe2
|
@common_utils.skipIfNoCaffe2
|
||||||
@common_utils.skipIfNoLapack
|
@common_utils.skipIfNoLapack
|
||||||
def test_caffe2_aten_fallback(self):
|
def test_caffe2_aten_fallback(self):
|
||||||
|
|||||||
@ -686,9 +686,19 @@ def _optimize_graph(
|
|||||||
graph = _C._jit_pass_canonicalize(graph)
|
graph = _C._jit_pass_canonicalize(graph)
|
||||||
_C._jit_pass_lint(graph)
|
_C._jit_pass_lint(graph)
|
||||||
if GLOBALS.onnx_shape_inference:
|
if GLOBALS.onnx_shape_inference:
|
||||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
try:
|
||||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||||
)
|
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if (
|
||||||
|
_C_onnx._CAFFE2_ATEN_FALLBACK
|
||||||
|
and exc.args[0]
|
||||||
|
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
|
||||||
|
):
|
||||||
|
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
|
||||||
|
pass
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
@ -1183,9 +1193,18 @@ def _model_to_graph(
|
|||||||
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||||
|
|
||||||
if GLOBALS.onnx_shape_inference:
|
if GLOBALS.onnx_shape_inference:
|
||||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
try:
|
||||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||||
)
|
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if (
|
||||||
|
_C_onnx._CAFFE2_ATEN_FALLBACK
|
||||||
|
and exc.args[0]
|
||||||
|
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
|
||||||
|
):
|
||||||
|
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
|
||||||
|
pass
|
||||||
|
|
||||||
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
|
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user