mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Implement sym_not (#152111)
Implement onnx support for sym_not. Replaces https://github.com/pytorch/pytorch/pull/147472 Fix https://github.com/pytorch/pytorch/issues/136572 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152111 Approved by: https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
6120cc8ccd
commit
a811d3351b
@ -542,6 +542,21 @@ class DynamoExporterTest(common_utils.TestCase):
|
||||
# make sre the naming is working
|
||||
self.assertEqual(onnx_program.model.graph.inputs[0].shape[0], "dx")
|
||||
|
||||
def test_export_sym_not(self):
|
||||
class SymNotModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
comparison = x.shape[0] == x.shape[1]
|
||||
return torch.sym_not(comparison)
|
||||
|
||||
inputs = (torch.zeros((2, 2)),)
|
||||
dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},)
|
||||
onnx_program = self.export(SymNotModel(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
onnx_testing.assert_onnx_program(onnx_program, args=inputs)
|
||||
self.assertIn(
|
||||
"Not",
|
||||
[node.op_type for node in onnx_program.model.graph],
|
||||
)
|
||||
|
||||
def test_group_norm_opset_21(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -66,15 +66,19 @@ def assert_onnx_program(
|
||||
torch_module = exported_program.module()
|
||||
torch_outputs, _ = _pytree.tree_flatten(torch_module(*args, **kwargs))
|
||||
# ONNX outputs are always real, so we need to convert torch complex outputs to real representations
|
||||
torch_outputs = [
|
||||
torch.view_as_real(output) if torch.is_complex(output) else output
|
||||
for output in torch_outputs
|
||||
]
|
||||
torch_outputs_adapted = []
|
||||
for output in torch_outputs:
|
||||
if not isinstance(output, torch.Tensor):
|
||||
torch_outputs_adapted.append(torch.tensor(output))
|
||||
elif torch.is_complex(output):
|
||||
torch_outputs_adapted.append(torch.view_as_real(output))
|
||||
else:
|
||||
torch_outputs_adapted.append(output)
|
||||
onnx_outputs = program(*args, **kwargs)
|
||||
# TODO(justinchuby): Include output names in the error message
|
||||
torch.testing.assert_close(
|
||||
tuple(onnx_outputs),
|
||||
tuple(torch_outputs),
|
||||
tuple(torch_outputs_adapted),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=True,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
__all__ = ["core", "hop", "nn", "symbolic"]
|
||||
__all__ = ["core", "hop", "nn", "symbolic", "symops"]
|
||||
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic, symops
|
||||
|
19
torch/onnx/_internal/exporter/_torchlib/ops/symops.py
Normal file
19
torch/onnx/_internal/exporter/_torchlib/ops/symops.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""Implementation for torch.sym* ops."""
|
||||
|
||||
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
|
||||
# ruff: noqa: TCH001,TCH002
|
||||
# flake8: noqa
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from onnxscript.onnx_opset import opset18 as op
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter._torchlib._tensor_typing import BOOL
|
||||
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
|
||||
|
||||
|
||||
@onnx_impl(torch.sym_not, trace_only=True)
|
||||
def sym_not(self: BOOL) -> BOOL:
|
||||
"""sym_not(SymBool self) -> SymBool"""
|
||||
return op.Not(self)
|
Reference in New Issue
Block a user