mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Refactor TorchScript-based exporter logic, moving it to a single (private) location for better code management. The original public module and method APIs are preserved. - Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly - Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused @albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and `torch/autograd/_functions/utils.py`? Thanks! ## BC Breaking - **Deprecated members in `torch.onnx.verification` are removed** Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323 Approved by: https://github.com/titaiwangms, https://github.com/angelayi
134 lines
4.4 KiB
Python
134 lines
4.4 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
|
|
|
|
from typing import Sequence # noqa: UP035
|
|
|
|
import onnx_test_common
|
|
import onnxscript
|
|
from onnxscript.onnx_types import FLOAT
|
|
|
|
import torch
|
|
from torch.onnx._internal.torchscript_exporter import jit_utils
|
|
from torch.testing._internal import common_utils
|
|
|
|
|
|
class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
    """Export models whose ops are lowered to onnxscript local functions.

    Each test registers a custom symbolic for an ``aten`` op. The symbolic
    emits an onnxscript function through ``g.onnxscript_op``, and the model is
    then run through the base class's ``run_test``.

    Every registration is paired with a cleanup that unregisters it. This keeps
    the custom symbolic from staying active for other tests executed in the
    same process.
    """

    # Opset 15 is pinned because:
    # 1. ONNX local functions are only supported from opset 15 onwards.
    # 2. onnxscript requires users to choose the opset used inside a local
    #    function explicitly (see the ``opset15`` imports in the tests).
    opset_version = 15

    def _register_symbolic(self, symbolic_name, symbolic_fn):
        """Register ``symbolic_fn`` for ``symbolic_name`` for this test only.

        Args:
            symbolic_name: Qualified op name, e.g. ``"aten::selu"``.
            symbolic_fn: Symbolic function to use when exporting that op.

        The matching ``unregister_custom_op_symbolic`` call is queued with
        ``addCleanup`` right after registration, so it runs even if the rest of
        the test fails.
        """
        torch.onnx.register_custom_op_symbolic(
            symbolic_name=symbolic_name,
            symbolic_fn=symbolic_fn,
            opset_version=self.opset_version,
        )
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            symbolic_name,
            self.opset_version,
        )

    def test_selu_from_onnxscript_example(self):
        """Export ``nn.SELU`` through a hand-written onnxscript ``Selu``."""
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model = torch.nn.SELU()

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

        @onnxscript.script(custom_opset)
        def Selu(
            X,
        ):
            # onnxscript does not support default argument values, so the
            # SELU constants are written inline instead of as parameters.
            alpha = 1.67326  # auto wrapped as Constants
            gamma = 1.0507
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # Emit the onnxscript function and give its output the input's type.
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_symbolic("aten::selu", custom_selu)

        self.run_test(model, x)

    def test_layer_norm(self):
        """Export ``aten::layer_norm`` through an onnxscript ``layer_norm``.

        The model mixes three LayerNorm layers (exported via the custom
        symbolic) with two CELU activations and a nested dropout module.
        """
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        class N(torch.nn.Module):
            def __init__(self, prob):
                super().__init__()
                self.dropout = torch.nn.Dropout(prob)

            def forward(self, x):
                return self.dropout(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                # Each LayerNorm gets a distinct eps (0, 1, 2, ...); note the
                # first layer therefore uses eps=0.
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
                )
                self.celu1 = torch.nn.CELU(1.0)
                self.celu2 = torch.nn.CELU(2.0)
                self.dropout = N(0.5)

            def forward(self, x, y, z):
                res1 = self.celu1(x)
                res2 = self.celu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.dropout(z)

        model = M(3)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1)

        @onnxscript.script(custom_opset)
        def layer_norm(
            X,
            axes: Sequence[int],
            weight: FLOAT[...],
            bias: FLOAT[...],
            eps: float,
        ):
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            normalizedw = op.CastLike(
                normalized, weight
            )  # Type issue if missing this Op
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias

        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g: jit_utils.GraphContext,
            input,
            normalized_shape,
            weight,
            bias,
            eps,
            cudnn_enable,
        ):
            # Comprehensions are not supported inside onnxscript functions, so
            # the reduction axes are computed here in Python and passed as an
            # attribute. They are the trailing len(normalized_shape) dims,
            # e.g. [-1] for normalized_shape=[3] or [-2, -1] for a 2-D shape.
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())

        self._register_symbolic("aten::layer_norm", custom_layer_norm)

        self.run_test(model, (x, y, z))
|
|
|
|
|
|
if __name__ == "__main__":
    # Only reached when this file is executed directly as a script; hands off
    # to PyTorch's shared test runner from torch.testing._internal.common_utils.
    common_utils.run_tests()
|