[ONNX] Implement sym_not (#152111)

Implement ONNX export support for torch.sym_not. Replaces https://github.com/pytorch/pytorch/pull/147472

Fixes https://github.com/pytorch/pytorch/issues/136572
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152111
Approved by: https://github.com/titaiwangms
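
For context, a minimal usage sketch of what this change enables. It is not part of the commit; it mirrors the new test added below and assumes a recent PyTorch build where torch.onnx.export(..., dynamo=True) returns an ONNXProgram.

# Sketch only: export a model whose forward negates a symbolic shape comparison.
# With dynamic shapes, x.shape[0] == x.shape[1] produces a SymBool, and torch.sym_not
# on it should now lower to an ONNX Not node.
import torch


class SymNotModel(torch.nn.Module):
    def forward(self, x):
        comparison = x.shape[0] == x.shape[1]
        return torch.sym_not(comparison)


inputs = (torch.zeros((2, 2)),)
dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
# dynamo=True selects the new exporter; the returned ONNXProgram exposes the ONNX graph.
onnx_program = torch.onnx.export(
    SymNotModel(), inputs, dynamic_shapes=dynamic_shapes, dynamo=True
)
print("Not" in [node.op_type for node in onnx_program.model.graph])  # expected: True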
Justin Chu
2025-04-25 07:50:33 +00:00
committed by PyTorch MergeBot
parent 6120cc8ccd
commit a811d3351b
4 changed files with 45 additions and 7 deletions

View File

@@ -542,6 +542,21 @@ class DynamoExporterTest(common_utils.TestCase):
         # make sure the naming is working
         self.assertEqual(onnx_program.model.graph.inputs[0].shape[0], "dx")

+    def test_export_sym_not(self):
+        class SymNotModel(torch.nn.Module):
+            def forward(self, x):
+                comparison = x.shape[0] == x.shape[1]
+                return torch.sym_not(comparison)
+
+        inputs = (torch.zeros((2, 2)),)
+        dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
+        onnx_program = self.export(SymNotModel(), inputs, dynamic_shapes=dynamic_shapes)
+        onnx_testing.assert_onnx_program(onnx_program, args=inputs)
+        self.assertIn(
+            "Not",
+            [node.op_type for node in onnx_program.model.graph],
+        )
+
     def test_group_norm_opset_21(self):
         class Model(torch.nn.Module):
             def forward(self, x):

View File

@@ -66,15 +66,19 @@ def assert_onnx_program(
     torch_module = exported_program.module()
     torch_outputs, _ = _pytree.tree_flatten(torch_module(*args, **kwargs))
     # ONNX outputs are always real, so we need to convert torch complex outputs to real representations
-    torch_outputs = [
-        torch.view_as_real(output) if torch.is_complex(output) else output
-        for output in torch_outputs
-    ]
+    torch_outputs_adapted = []
+    for output in torch_outputs:
+        if not isinstance(output, torch.Tensor):
+            torch_outputs_adapted.append(torch.tensor(output))
+        elif torch.is_complex(output):
+            torch_outputs_adapted.append(torch.view_as_real(output))
+        else:
+            torch_outputs_adapted.append(output)
     onnx_outputs = program(*args, **kwargs)
     # TODO(justinchuby): Include output names in the error message
     torch.testing.assert_close(
         tuple(onnx_outputs),
-        tuple(torch_outputs),
+        tuple(torch_outputs_adapted),
         rtol=rtol,
         atol=atol,
         equal_nan=True,

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations

-__all__ = ["core", "hop", "nn", "symbolic"]
+__all__ = ["core", "hop", "nn", "symbolic", "symops"]

-from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic
+from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic, symops

View File

@@ -0,0 +1,19 @@
+"""Implementation for torch.sym* ops."""
+
+# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
+# ruff: noqa: TCH001,TCH002
+# flake8: noqa
+
+from __future__ import annotations
+
+from onnxscript.onnx_opset import opset18 as op
+
+import torch
+from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL
+from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
+
+
+@onnx_impl(torch.sym_not, trace_only=True)
+def sym_not(self: BOOL) -> BOOL:
+    """sym_not(SymBool self) -> SymBool"""
+    return op.Not(self)