Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ONNX] Support sym_float (#153200)

Fixes #153115

Note: torch.sym_int is not supported in this PR because it does not appear in the exported program; it shows up as `torch.ops.aten.sym_size.int()` instead:

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]"):
            sym_size_int_1: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0);  x = None
            return (sym_size_int_1,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153200
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
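For context, a minimal sketch (not part of the commit; the `Model` class and concrete shapes here are illustrative) that reproduces the printout above: returning a dynamic dimension lowers to `torch.ops.aten.sym_size.int`, never to `torch.sym_int`.

```python
# Illustrative repro: export a model that returns a dynamic dimension.
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x.shape[0]  # a SymInt when dim 0 is dynamic


dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
ep = torch.export.export(Model(), (torch.zeros(3, 4),), dynamic_shapes=dynamic_shapes)
print(ep)  # the graph contains torch.ops.aten.sym_size.int(x, 0)
```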
Committed by: PyTorch MergeBot
Parent: da0b89bcbf
Commit: 90fde0dc09
```diff
@@ -586,6 +586,23 @@ class DynamoExporterTest(common_utils.TestCase):
             [node.op_type for node in onnx_program.model.graph],
         )
 
+    def test_export_sym_float(self):
+        class SymFloatModel(torch.nn.Module):
+            def forward(self, x):
+                a = x.shape[0]
+                return torch.sym_float(a)
+
+        inputs = (torch.zeros((2, 2)),)
+        dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
+        onnx_program = self.export(
+            SymFloatModel(), inputs, dynamic_shapes=dynamic_shapes
+        )
+        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
+        self.assertIn(
+            "Cast",
+            [node.op_type for node in onnx_program.model.graph],
+        )
+
     def test_group_norm_opset_21(self):
         class Model(torch.nn.Module):
             def forward(self, x):
```
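The test above goes through the suite's `self.export` helper. For reference, a sketch (assuming a recent PyTorch where `torch.onnx.export(..., dynamo=True)` returns an `ONNXProgram`) of exercising the new lowering directly:

```python
# Sketch: export SymFloatModel with the dynamo-based exporter and check
# that torch.sym_float lowered to an ONNX Cast node.
import torch


class SymFloatModel(torch.nn.Module):
    def forward(self, x):
        a = x.shape[0]
        return torch.sym_float(a)


onnx_program = torch.onnx.export(
    SymFloatModel(),
    (torch.zeros(2, 2),),
    dynamic_shapes=({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},),
    dynamo=True,
)
print([node.op_type for node in onnx_program.model.graph])  # expect "Cast"
```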
```diff
@@ -29,7 +29,7 @@ from onnxscript import (
 # NOTE: We do not care about unsigned types beyond UINT8 because PyTorch does not use them.
 # More detail can be found: https://pytorch.org/docs/stable/tensors.html
 
-_TensorType = Union[
+TensorType = Union[
     BFLOAT16,
     BOOL,
     COMPLEX64,
@@ -56,11 +56,11 @@ RealType = Union[
     INT64,
 ]
 
-TTensor = TypeVar("TTensor", bound=_TensorType)
+TTensor = TypeVar("TTensor", bound=TensorType)
 # Duplicate TTensor for inputs/outputs that accept the same set of types as TTensor
 # but do not constrain the type to be the same as the other inputs/outputs
-TTensor2 = TypeVar("TTensor2", bound=_TensorType)
-TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
+TTensor2 = TypeVar("TTensor2", bound=TensorType)
+TTensorOrString = TypeVar("TTensorOrString", bound=Union[TensorType, STRING])
 TFloat = TypeVar("TFloat", bound=_FloatType)
 TFloatOrUInt8 = TypeVar(
     "TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8]
```
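The two hunks above rename `_TensorType` to `TensorType`, making the union importable by op implementations such as `sym_float` below. A minimal sketch (with stand-in classes, not the real onnxscript types) of how a TypeVar bounded by such a union ties an op's output type to its input type:

```python
# Stand-in types; the real module bounds TTensor by a union of onnxscript
# tensor classes (BFLOAT16, BOOL, COMPLEX64, ...).
from typing import TypeVar, Union


class FLOAT: ...


class INT64: ...


TensorType = Union[FLOAT, INT64]
TTensor = TypeVar("TTensor", bound=TensorType)


def identity(x: TTensor) -> TTensor:  # output type matches the input type
    return x
```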
```diff
@@ -9,10 +9,22 @@ from __future__ import annotations
 from onnxscript.onnx_opset import opset18 as op
 
 import torch
-from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL, IntType
+from torch.onnx._internal.exporter._torchlib._tensor_typing import (
+    BOOL,
+    FLOAT,
+    INT64,
+    IntType,
+    TensorType,
+)
 from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
 
 
+@onnx_impl(torch.sym_float, trace_only=True)
+def sym_float(self: TensorType) -> FLOAT:
+    """sym_float(SymInt self) -> SymFloat"""
+    return op.Cast(self, to=FLOAT.dtype)
+
+
 @onnx_impl(torch.sym_max, trace_only=True)
 def sym_max(x: IntType, y: IntType) -> IntType:
     """sym_max(SymInt x, SymInt y) -> SymInt"""
```
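The lowering itself is a single ONNX `Cast`: at export time the SymInt is traced as a scalar INT64 value, and casting it to FLOAT yields the symbolic float. A standalone sketch of the equivalent onnxscript function (assuming `onnxscript` is installed; this is illustrative, not the registered implementation):

```python
# Equivalent Cast-to-FLOAT lowering as a standalone onnxscript function,
# evaluated eagerly on a numpy input.
import numpy as np
from onnxscript import FLOAT, INT64, script
from onnxscript.onnx_opset import opset18 as op


@script()
def sym_float_sketch(x: INT64) -> FLOAT:
    return op.Cast(x, to=FLOAT.dtype)


print(sym_float_sketch(np.array(3, dtype=np.int64)))  # -> 3.0
```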