[ONNX] Support sym_float (#153200)

Fixes #153115

Note: torch.sym_int is not supported in this PR because it's not appeared in exported program, instead, it's `torch.ops.aten.sym_size.int()`.

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]"):
             #
            sym_size_int_1: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0);  x = None
            return (sym_size_int_1,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153200
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
Ti-Tai Wang
2025-05-09 19:10:13 +00:00
committed by PyTorch MergeBot
parent da0b89bcbf
commit 90fde0dc09
3 changed files with 34 additions and 5 deletions

View File

@ -586,6 +586,23 @@ class DynamoExporterTest(common_utils.TestCase):
[node.op_type for node in onnx_program.model.graph],
)
def test_export_sym_float(self):
class SymFloatModel(torch.nn.Module):
def forward(self, x):
a = x.shape[0]
return torch.sym_float(a)
inputs = (torch.zeros((2, 2)),)
dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
onnx_program = self.export(
SymFloatModel(), inputs, dynamic_shapes=dynamic_shapes
)
onnx_testing.assert_onnx_program(onnx_program, args=inputs)
self.assertIn(
"Cast",
[node.op_type for node in onnx_program.model.graph],
)
def test_group_norm_opset_21(self):
class Model(torch.nn.Module):
def forward(self, x):

View File

@ -29,7 +29,7 @@ from onnxscript import (
# NOTE: We do not care about unsigned types beyond UINT8 because PyTorch does not us them.
# More detail can be found: https://pytorch.org/docs/stable/tensors.html
_TensorType = Union[
TensorType = Union[
BFLOAT16,
BOOL,
COMPLEX64,
@ -56,11 +56,11 @@ RealType = Union[
INT64,
]
TTensor = TypeVar("TTensor", bound=_TensorType)
TTensor = TypeVar("TTensor", bound=TensorType)
# Duplicate TTensor for inputs/outputs that accept the same set of types as TTensor
# but do not constrain the type to be the same as the other inputs/outputs
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
TTensor2 = TypeVar("TTensor2", bound=TensorType)
TTensorOrString = TypeVar("TTensorOrString", bound=Union[TensorType, STRING])
TFloat = TypeVar("TFloat", bound=_FloatType)
TFloatOrUInt8 = TypeVar(
"TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8]

View File

@ -9,10 +9,22 @@ from __future__ import annotations
from onnxscript.onnx_opset import opset18 as op
import torch
from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL, IntType
from torch.onnx._internal.exporter._torchlib._tensor_typing import (
BOOL,
FLOAT,
INT64,
IntType,
TensorType,
)
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
@onnx_impl(torch.sym_float, trace_only=True)
def sym_float(self: TensorType) -> FLOAT:
"""sym_float(SymInt self) -> SymFloat"""
return op.Cast(self, to=FLOAT.dtype)
@onnx_impl(torch.sym_max, trace_only=True)
def sym_max(x: IntType, y: IntType) -> IntType:
"""sym_max(SymInt x, SymInt y) -> SymInt"""