[ONNX] add converters for sym_min, sym_max (#152196)

Conversion of Phi4-multimodel-instruct fails because of missing converters for torch.sym_max, and torch.sym_min.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152196
Approved by: https://github.com/justinchuby
This commit is contained in:
xadupre
2025-04-25 20:01:01 +00:00
committed by PyTorch MergeBot
parent 9336608307
commit 91c590f048
2 changed files with 41 additions and 1 deletions

View File

@ -542,6 +542,34 @@ class DynamoExporterTest(common_utils.TestCase):
# make sre the naming is working
self.assertEqual(onnx_program.model.graph.inputs[0].shape[0], "dx")
def test_export_sym_max(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.sym_max(*x.shape)
inputs = (torch.zeros((2, 3)),)
dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
onnx_testing.assert_onnx_program(onnx_program, args=inputs)
self.assertIn(
"Max",
[node.op_type for node in onnx_program.model.graph],
)
def test_export_sym_min(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.sym_min(*x.shape)
inputs = (torch.zeros((2, 3)),)
dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
onnx_testing.assert_onnx_program(onnx_program, args=inputs)
self.assertIn(
"Min",
[node.op_type for node in onnx_program.model.graph],
)
def test_export_sym_not(self):
class SymNotModel(torch.nn.Module):
def forward(self, x):

View File

@ -9,10 +9,22 @@ from __future__ import annotations
from onnxscript.onnx_opset import opset18 as op
import torch
from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL
from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL, IntType
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
@onnx_impl(torch.sym_max, trace_only=True)
def sym_max(x: IntType, y: IntType) -> IntType:
"""sym_max(SymInt x, SymInt y) -> SymInt"""
return op.Max(x, y)
@onnx_impl(torch.sym_min, trace_only=True)
def sym_min(x: IntType, y: IntType) -> IntType:
"""sym_min(SymInt x, SymInt y) -> SymInt"""
return op.Min(x, y)
@onnx_impl(torch.sym_not, trace_only=True)
def sym_not(self: BOOL) -> BOOL:
"""sym_not(SymBool self) -> SymBool"""