pytorch/test/onnx/test_custom_ops.py
AllenTiTaiWang b782beb18e [ONNX] Expose OnnxRegistry publicly (#106140)
The official move of `OnnxRegistry` into `torch.onnx` allows it to become one of the parameters of `torch.onnx.ExportOptions`. By incorporating `OnnxRegistry` into `torch.onnx.ExportOptions`, users gain access to several capabilities: registering custom operators with `register_custom_op`, checking whether an operator is supported with `is_registered_op`, and obtaining the symbolic functions that support a specific operator with `get_functions`.
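
As a minimal sketch of querying the registry and wiring it into the exporter — the method names follow this commit message (`get_functions` in particular) and may differ in later releases:

```python
import torch

# Sketch only: method names follow the commit message and the exact
# signatures may differ across PyTorch releases.
registry = torch.onnx.OnnxRegistry()

# Is there a registered decomposition for aten::gelu?
print(registry.is_registered_op(namespace="aten", op_name="gelu"))

# Fetch the ONNX Script function(s) backing that operator.
functions = registry.get_functions(namespace="aten", op_name="gelu")

# Make the exporter use this registry.
export_options = torch.onnx.ExportOptions(onnx_registry=registry)
```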

Additionally, `opset_version` is now exclusively available on `torch.onnx.OnnxRegistry`, as it has been removed from `torch.onnx.ExportOptions`. Initializing the registry with torchlib under the provided opset version ensures that the exporter uses that opset version as the primary version for exporting.
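
A sketch of that split, assuming `opset_version` is exposed as a property on the registry:

```python
import torch

registry = torch.onnx.OnnxRegistry()
# The target opset now lives on the registry (which is initialized with
# torchlib under this version), not on ExportOptions.
print(registry.opset_version)
```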

These changes encompass scenarios where users can:

1. Register an unsupported ATen operator with a custom implementation using ONNX Script.
2. Override an existing symbolic function (ONNX invariant).

NOTE: Custom registered functions take priority in the ONNX dispatcher, and if multiple custom functions are registered for the same operator, the one registered last is picked.
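
As a rough illustration of scenario 1, here is a minimal sketch of registering a custom `aten::add` implementation and handing the registry to the exporter. It assumes the method name given above (`register_custom_op`; later releases may rename it), and the `custom.aten` domain and `custom_aten_add` function are hypothetical names chosen for the example:

```python
import onnxscript
import torch
from onnxscript import opset18 as op

# Hypothetical custom domain for the replacement implementation.
custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)

@onnxscript.script(custom_aten)
def custom_aten_add(x, y, alpha: float = 1.0):
    # alpha is an attribute; cast it so it can be multiplied with y.
    alpha = op.CastLike(alpha, y)
    y = op.Mul(y, alpha)
    return op.Add(x, y)

registry = torch.onnx.OnnxRegistry()
# Method name as given in this commit message; the exact signature is
# version-dependent.
registry.register_custom_op(
    function=custom_aten_add, namespace="aten", op_name="add"
)

export_options = torch.onnx.ExportOptions(onnx_registry=registry)

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y

onnx_program = torch.onnx.dynamo_export(
    Model(), torch.randn(2), torch.randn(2), export_options=export_options
)
```

Because the last-registered function wins, registering `custom_aten_add` on top of torchlib's default `aten::add` implementation is what causes the dispatcher to pick it.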
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106140
Approved by: https://github.com/justinchuby, https://github.com/thiagocrepaldi
2023-08-04 20:46:03 +00:00


# Owner(s): ["module: onnx"]

import onnx_test_common
import pytorch_test_common

import torch
import torch.utils.cpp_extension
from torch.onnx import symbolic_helper
from torch.testing._internal import common_utils


class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase):
    opset_version = 9
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_symbolic(self):
        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

            @staticmethod
            def symbolic(g, input, scalar):
                # Export the autograd.Function as a single ONNX Clip node.
                return g.op("Clip", input, min_f=scalar)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.clip = MyClip.apply

            def forward(self, x):
                h = self.clip(x, 2)
                return h

        x = torch.randn(2, 3, 4, requires_grad=True)
        model = MyModule()
        onnx_test_common.run_model_test(self, model, input_args=(x,))

    def test_register_op(self):
        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

        class MyRelu(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.clip = MyClip.apply
                self.relu = MyRelu.apply

            def forward(self, x):
                h = self.clip(x, 2)
                h = self.relu(h)
                return h

        # autograd.Functions without a symbolic() surface as prim::PythonOp
        # nodes; dispatch on the wrapped function's name.
        def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs):
            n = ctx.cur_node
            name = kwargs["name"]
            if name == "MyClip":
                return g.op("Clip", args[0], min_f=args[1], outputs=n.outputsSize())
            elif name == "MyRelu":
                return g.op("Relu", args[0], outputs=n.outputsSize())
            else:
                return symbolic_helper._unimplemented(
                    "prim::PythonOp", "unknown node kind: " + name
                )

        from torch.onnx import register_custom_op_symbolic

        register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1)

        x = torch.randn(2, 3, 4, requires_grad=True)
        model = MyModule()
        onnx_test_common.run_model_test(self, model, input_args=(x,))


class TestExportAsContribOps(pytorch_test_common.ExportTestCase):
    opset_version = 14
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_contrib_op_with_loop(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gelu = torch.nn.GELU(approximate="none")

            def forward(self, x):
                res = []
                res2 = []
                for i in range(x.size(0)):
                    if len(res) > 0:
                        res2.append(res[0])
                    else:
                        res2.append(self.gelu(x[0]))
                    res.append(x[0])
                return torch.stack(res), torch.stack(res2)

        # Export aten::gelu as the com.microsoft contrib Gelu op.
        def symbolic_custom_gelu(g, input, approximate):
            return g.op("com.microsoft::Gelu", input).setType(input.type())

        from torch.onnx import register_custom_op_symbolic

        register_custom_op_symbolic("::gelu", symbolic_custom_gelu, 1)

        x = torch.randn(3, 3, 4, requires_grad=True)
        model = torch.jit.script(M())
        onnx_test_common.run_model_test(self, model, input_args=(x,))


if __name__ == "__main__":
    common_utils.run_tests()