diff --git a/test/onnx/exporter/test_small_models_e2e.py b/test/onnx/exporter/test_small_models_e2e.py index 98d2c468a7c8..e943f3f4c2be 100644 --- a/test/onnx/exporter/test_small_models_e2e.py +++ b/test/onnx/exporter/test_small_models_e2e.py @@ -542,6 +542,34 @@ class DynamoExporterTest(common_utils.TestCase): # make sre the naming is working self.assertEqual(onnx_program.model.graph.inputs[0].shape[0], "dx") + def test_export_sym_max(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.sym_max(*x.shape) + + inputs = (torch.zeros((2, 3)),) + dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},) + onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes) + onnx_testing.assert_onnx_program(onnx_program, args=inputs) + self.assertIn( + "Max", + [node.op_type for node in onnx_program.model.graph], + ) + + def test_export_sym_min(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.sym_min(*x.shape) + + inputs = (torch.zeros((2, 3)),) + dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},) + onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes) + onnx_testing.assert_onnx_program(onnx_program, args=inputs) + self.assertIn( + "Min", + [node.op_type for node in onnx_program.model.graph], + ) + def test_export_sym_not(self): class SymNotModel(torch.nn.Module): def forward(self, x): diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/symops.py b/torch/onnx/_internal/exporter/_torchlib/ops/symops.py index 11451ed6ed24..0641663d56fa 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/symops.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/symops.py @@ -9,10 +9,22 @@ from __future__ import annotations from onnxscript.onnx_opset import opset18 as op import torch -from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL +from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL, IntType from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl +@onnx_impl(torch.sym_max, trace_only=True) +def sym_max(x: IntType, y: IntType) -> IntType: + """sym_max(SymInt x, SymInt y) -> SymInt""" + return op.Max(x, y) + + +@onnx_impl(torch.sym_min, trace_only=True) +def sym_min(x: IntType, y: IntType) -> IntType: + """sym_min(SymInt x, SymInt y) -> SymInt""" + return op.Min(x, y) + + @onnx_impl(torch.sym_not, trace_only=True) def sym_not(self: BOOL) -> BOOL: """sym_not(SymBool self) -> SymBool"""