mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[ONNX] Add onnx::Gelu support for version 20 (#128773)
Fixes https://github.com/pytorch/pytorch/issues/128772 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128773 Approved by: https://github.com/justinchuby
This commit is contained in:
@ -1358,6 +1358,8 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
iter = graph.nodes()
|
||||
self.assertEqual(next(iter).kind(), "custom_namespace::custom_op")
|
||||
|
||||
# gelu is exported as onnx::Gelu for opset >= 20
|
||||
@skipIfUnsupportedMaxOpsetVersion(19)
|
||||
def test_custom_opsets_gelu(self):
|
||||
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)
|
||||
|
||||
@ -1382,6 +1384,8 @@ class TestUtilityFuns(_BaseTestCase):
|
||||
self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
|
||||
self.assertEqual(graph.opset_import[1].version, 1)
|
||||
|
||||
# gelu is exported as onnx::Gelu for opset >= 20
|
||||
@skipIfUnsupportedMaxOpsetVersion(19)
|
||||
def test_register_aten_custom_op_symbolic(self):
|
||||
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)
|
||||
|
||||
|
@ -32,7 +32,7 @@ from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
__all__ = ["_grid_sampler", "_affine_grid_generator"]
|
||||
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
|
||||
|
||||
|
||||
def convert_grid_sample_mode(mode_s):
|
||||
@ -84,3 +84,10 @@ def _affine_grid_generator(
|
||||
size,
|
||||
align_corners_i=int(align_corners),
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::gelu")
|
||||
@symbolic_helper.parse_args("v", "s")
|
||||
@_beartype.beartype
|
||||
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
|
||||
return g.op("Gelu", self, approximate_s=approximate)
|
||||
|
Reference in New Issue
Block a user