mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR fixes the problem of having the `Where` operator bound to different types in cases where the dtype is not explicitly set. The PR extends the implicit casting to the onnx::Where operator to fix the issue, and includes the corresponding unit test. Fixes #118733 Pull Request resolved: https://github.com/pytorch/pytorch/pull/120619 Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
200 lines
6.3 KiB
Python
200 lines
6.3 KiB
Python
# Owner(s): ["module: onnx"]
|
|
import onnxruntime
|
|
import pytorch_test_common
|
|
|
|
import torch
|
|
from pytorch_test_common import skipIfNoCuda
|
|
from torch.onnx import verification
|
|
from torch.onnx._globals import GLOBALS
|
|
from torch.testing._internal import common_utils
|
|
|
|
|
|
def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
    r"""
    Export a torch::jit::Graph object to a serialized ONNX ModelProto.

    This function is for testing purposes. It only keeps the essential parts
    for IR graph conversions, and it does not interact with actual PyTorch
    modules nor PyTorch tensor inputs.

    Args:
        graph: The TorchScript graph to convert (for example the result of
            ``torch._C.parse_ir``).
        operator_export_type: A ``torch.onnx.OperatorExportTypes`` value.
        opset_version: Target ONNX opset version.

    Returns:
        The serialized ONNX model as produced by ``graph._export_onnx``
        (suitable for ``onnxruntime.InferenceSession``).

    ``GLOBALS.export_onnx_opset_version`` is set to ``opset_version`` only for
    the duration of the optimization and export, and is restored afterwards
    (even on failure) so that one test's opset does not leak into the next.
    """
    previous_opset_version = GLOBALS.export_onnx_opset_version
    # The exporter's Python-side passes presumably read the target opset from
    # GLOBALS rather than from an argument, so it must be set before optimizing.
    GLOBALS.export_onnx_opset_version = opset_version
    try:
        graph = torch.onnx.utils._optimize_graph(
            graph, operator_export_type, params_dict={}
        )
        # Positional arguments; the roles noted below follow the order used by
        # torch.onnx.utils' own call to this binding -- confirm if it changes.
        proto, _, _, _ = graph._export_onnx(
            {},  # initializers (params_dict): none
            opset_version,
            {},  # dynamic_axes: none
            False,  # defer_weight_export
            operator_export_type,
            False,  # strip_doc_string
            False,  # keep_initializers_as_inputs
            {},  # custom_opsets: none
            True,  # add_node_names
            "",  # onnx_file_path: not writing external data
            {},  # node_attr_to_name: none
        )
    finally:
        GLOBALS.export_onnx_opset_version = previous_opset_version
    return proto
|
|
|
|
|
|
class _TestJITIRToONNX:
    """Abstract base class for test cases.

    Intentionally not a sub-class of unittest.TestCase so that unittest / pytest
    don't run it directly. Concrete, runnable sub-types are created by
    MakeTestCase(), which copies this class's attributes onto a new type that
    derives from pytorch_test_common.ExportTestCase and overrides
    ``opset_version``.
    """

    opset_version = -1  # Sub-classes must override
    ort_providers = ["CPUExecutionProvider"]
    check_shape = True
    check_dtype = True
    ignore_none = True  # True for tracing, and False for scripting

    def run_test(self, graph_ir, example_inputs, parse_tensor_constants=False):
        """Check that ONNX Runtime reproduces the JIT interpreter on ``graph_ir``.

        The IR text is parsed into a graph and executed with the TorchScript
        interpreter; the same graph is exported to ONNX at ``self.opset_version``
        and run with ONNX Runtime using ``self.ort_providers``. The two sets of
        outputs are compared with ``rtol=1e-3`` / ``atol=1e-7`` and this class's
        ``check_shape`` / ``check_dtype`` / ``ignore_none`` settings.

        Args:
            graph_ir: TorchScript IR text accepted by ``torch._C.parse_ir``.
            example_inputs: Positional inputs for the graph, in graph-input order.
            parse_tensor_constants: Forwarded to ``torch._C.parse_ir``; set to
                True when the IR contains tensor-valued ``prim::Constant`` nodes
                (as in ``test_where_constants``).
        """
        graph = torch._C.parse_ir(graph_ir, parse_tensor_constants)
        jit_outs = torch._C._jit_interpret_graph(graph, example_inputs)

        onnx_proto = _jit_graph_to_onnx_model(
            graph, torch.onnx.OperatorExportTypes.ONNX, self.opset_version
        )
        ort_sess = onnxruntime.InferenceSession(
            onnx_proto, providers=self.ort_providers
        )
        ort_outs = verification._run_onnx(ort_sess, example_inputs)

        options = verification.VerificationOptions(
            rtol=1e-3,
            atol=1e-7,
            check_shape=self.check_shape,
            check_dtype=self.check_dtype,
            ignore_none=self.ignore_none,
            acceptable_error_percentage=None,
        )
        verification._compare_onnx_pytorch_outputs(
            ort_outs,
            jit_outs,
            options,
        )

    def test_example_ir(self):
        """Minimal end-to-end example: element-wise ``aten::add`` of two inputs."""
        graph_ir = """
        graph(%1 : Float(2, 3),
              %2 : Float(2, 3)):
          %3 : int = prim::Constant[value=1]()
          %4 : Float(2, 3) = aten::add(%1, %2, %3)
          return (%4)
        """
        a = torch.randn(2, 3)
        b = torch.randn(2, 3)
        self.run_test(graph_ir, (a, b))

    def test_where_constants(self):
        """``aten::where`` mixing a Float(8) input with a Double scalar constant.

        The branches have different dtypes and the output dtype is not set
        explicitly, so the export must reconcile the Where input types.
        """
        graph_ir = """
        graph(%0 : Bool(8, device=cpu),
              %1 : Float(8, device=cpu)):
          %3 : Double(device=cpu) = prim::Constant[value={0.}]()
          %4 : Float(8) = aten::where(%0, %1, %3)
          return (%4)
        """
        a = torch.zeros(8, dtype=torch.bool)
        b = torch.zeros(8)
        self.run_test(graph_ir, (a, b), parse_tensor_constants=True)

    def test_add_sub_with_graph_inputs(self):
        """add / sub / rsub where the ``alpha`` scalar is itself a graph input."""
        for op in ["add", "sub", "rsub"]:
            graph_ir = f"""
            graph(%1 : Float(2, 3),
                  %2 : Float(2, 3),
                  %3 : int):
              %4 : Float(2, 3) = aten::{op}(%1, %2, %3)
              return (%4)
            """
            a = torch.randn(2, 3)
            b = torch.randn(2, 3)
            self.run_test(graph_ir, (a, b, 2))

    def test_native_layer_norm(self):
        """``aten::native_layer_norm`` over the last two dims, checking all three outputs."""
        graph_ir = """
        graph(%x : Float(2, 3, 2),
              %w : Float(3, 2),
              %b : Float(3, 2)):
          %5 : int = prim::Constant[value=3]()
          %6 : int = prim::Constant[value=2]()
          %7 : int[] = prim::ListConstruct(%5, %6)
          %10 : float = prim::Constant[value=1.0000000000000001e-05]()
          %11 : Float(2, 3, 2), %12 : Float(2, 1, 1), %13 : Float(2, 1, 1) = aten::native_layer_norm(%x, %7, %w, %b, %10)
          return (%11, %12, %13)
        """
        x = torch.randn(2, 3, 2)
        w = torch.randn(3, 2)
        b = torch.randn(3, 2)
        self.run_test(graph_ir, (x, w, b))

    def test_convolution(self):
        """``aten::convolution`` with untyped ``Tensor`` inputs and no bias."""
        graph_ir = """
        graph(%1 : Tensor,
              %2 : Tensor):
          %3 : NoneType = prim::Constant()
          %4 : int[] = prim::Constant[value=[1, 1]]()
          %5 : int[] = prim::Constant[value=[0, 0]]()
          %6 : bool = prim::Constant[value=0]()
          %7 : int = prim::Constant[value=1]()
          %8 : Tensor = aten::convolution(%1, %2, %3, %4, %5, %4, %6, %5, %7)
          return (%8)
        """
        x = torch.randn(8, 1, 5, 5)
        w = torch.randn(4, 1, 3, 3)
        self.run_test(graph_ir, (x, w))

    def test_log_softmax(self):
        """``aten::_log_softmax`` with ``half_to_float`` disabled."""
        graph_ir = """
        graph(%x: Tensor):
          %half_to_float: bool = prim::Constant[value=0]()
          %dim: int = prim::Constant[value=1]()
          %y = aten::_log_softmax(%x, %dim, %half_to_float)
          return (%y)
        """
        x = torch.randn(5, 2)
        self.run_test(graph_ir, (x,))

    @skipIfNoCuda
    def test_log_softmax_half_to_float(self):
        """``aten::_log_softmax`` with ``half_to_float`` enabled on a half CUDA input."""
        graph_ir = """
        graph(%x: Tensor):
          %half_to_float: bool = prim::Constant[value=1]()
          %dim: int = prim::Constant[value=1]()
          %y = aten::_log_softmax(%x, %dim, %half_to_float)
          return (%y)
        """
        x = torch.randn(5, 2).half().to("cuda")
        self.run_test(graph_ir, (x,))

    def test_native_dropout(self):
        """``aten::native_dropout`` in training mode, checking both outputs."""
        # p=0.0 so that training-mode dropout drops nothing and the outputs are
        # comparable between the interpreter and ONNX Runtime.
        graph_ir = """
        graph(%1 : Float(2, 3)):
          %2 : float = prim::Constant[value=0.0]()
          %training : bool = prim::Constant[value=1]()
          %3 : Tensor, %4 : Tensor = aten::native_dropout(%1, %2, %training)
          return (%3, %4)
        """
        a = torch.randn(2, 3)
        self.run_test(graph_ir, (a,))
|
|
|
|
|
|
def MakeTestCase(opset_version: int) -> type:
    """Create a concrete, runnable test class targeting ``opset_version``.

    The returned class is named ``TestJITIRToONNX_opset<opset_version>``. It
    derives from ``pytorch_test_common.ExportTestCase`` and carries every
    attribute defined on ``_TestJITIRToONNX`` (``run_test`` and the tests), with
    ``opset_version`` overridden.

    The ``__dict__`` and ``__weakref__`` entries of ``_TestJITIRToONNX.__dict__``
    are not copied. They are attribute descriptors bound to
    ``_TestJITIRToONNX``, and the generated class does not derive from it.
    Copying them would make ``instance.__dict__`` / ``vars(instance)`` raise
    ``TypeError`` on the generated class. Skipping them lets the base test case
    class provide working versions.
    """
    name = f"TestJITIRToONNX_opset{opset_version}"
    namespace = {
        key: value
        for key, value in _TestJITIRToONNX.__dict__.items()
        if key not in ("__dict__", "__weakref__")
    }
    namespace["opset_version"] = opset_version
    return type(name, (pytorch_test_common.ExportTestCase,), namespace)
|
|
|
|
|
|
# The concrete test class collected by the runner: exports at ONNX opset 14.
TestJITIRToONNX_opset14 = MakeTestCase(opset_version=14)

if __name__ == "__main__":
    # Run the module through PyTorch's shared test entry point.
    common_utils.run_tests()
|