mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ONNX] add converters for sym_min, sym_max (#152196)
Conversion of Phi4-multimodel-instruct fails because of missing converters for torch.sym_max, and torch.sym_min. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152196 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
9336608307
commit
91c590f048
@ -542,6 +542,34 @@ class DynamoExporterTest(common_utils.TestCase):
|
||||
# make sre the naming is working
|
||||
self.assertEqual(onnx_program.model.graph.inputs[0].shape[0], "dx")
|
||||
|
||||
def test_export_sym_max(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sym_max(*x.shape)
|
||||
|
||||
inputs = (torch.zeros((2, 3)),)
|
||||
dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
|
||||
onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
onnx_testing.assert_onnx_program(onnx_program, args=inputs)
|
||||
self.assertIn(
|
||||
"Max",
|
||||
[node.op_type for node in onnx_program.model.graph],
|
||||
)
|
||||
|
||||
def test_export_sym_min(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sym_min(*x.shape)
|
||||
|
||||
inputs = (torch.zeros((2, 3)),)
|
||||
dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
|
||||
onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
onnx_testing.assert_onnx_program(onnx_program, args=inputs)
|
||||
self.assertIn(
|
||||
"Min",
|
||||
[node.op_type for node in onnx_program.model.graph],
|
||||
)
|
||||
|
||||
def test_export_sym_not(self):
|
||||
class SymNotModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -9,10 +9,22 @@ from __future__ import annotations
|
||||
from onnxscript.onnx_opset import opset18 as op
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL
|
||||
from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL, IntType
|
||||
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
|
||||
|
||||
|
||||
@onnx_impl(torch.sym_max, trace_only=True)
|
||||
def sym_max(x: IntType, y: IntType) -> IntType:
|
||||
"""sym_max(SymInt x, SymInt y) -> SymInt"""
|
||||
return op.Max(x, y)
|
||||
|
||||
|
||||
@onnx_impl(torch.sym_min, trace_only=True)
|
||||
def sym_min(x: IntType, y: IntType) -> IntType:
|
||||
"""sym_min(SymInt x, SymInt y) -> SymInt"""
|
||||
return op.Min(x, y)
|
||||
|
||||
|
||||
@onnx_impl(torch.sym_not, trace_only=True)
|
||||
def sym_not(self: BOOL) -> BOOL:
|
||||
"""sym_not(SymBool self) -> SymBool"""
|
||||
|
Reference in New Issue
Block a user