mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Refactor TorchScript-based exporter logic to move it to a single (private) location for better code management. Original public module and method APIs are preserved. - Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly - Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused @albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and `torch/autograd/_functions/utils.py`? Thanks! ## BC Breaking - **Deprecated members in `torch.onnx.verification` are removed** Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323 Approved by: https://github.com/titaiwangms, https://github.com/angelayi
		
			
				
	
	
		
			134 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			134 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Owner(s): ["module: onnx"]
 | 
						|
 | 
						|
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
 | 
						|
 | 
						|
from typing import Sequence  # noqa: UP035
 | 
						|
 | 
						|
import onnx_test_common
 | 
						|
import onnxscript
 | 
						|
from onnxscript.onnx_types import FLOAT
 | 
						|
 | 
						|
import torch
 | 
						|
from torch.onnx._internal.torchscript_exporter import jit_utils
 | 
						|
from torch.testing._internal import common_utils
 | 
						|
 | 
						|
 | 
						|
class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
    """Exercise onnxscript-defined custom symbolics through the ONNX exporter.

    Each test writes an onnxscript function, registers a symbolic function
    that emits it for an ATen op, and hands the model to the inherited
    ``run_test`` helper.
    """

    # opset_version is pinned to 15 because:
    # 1. local function is supported after opset 15
    # 2. onnx-script requires users to determine opset in local function
    opset_version = 15

    def _register_symbolic_with_cleanup(self, symbolic_name, symbolic_fn):
        """Register a custom symbolic and undo the registration after the test.

        The custom-op registry is shared by the whole process, so an override
        that is never removed would leak into every test that runs afterwards.

        Args:
            symbolic_name: Qualified op name, e.g. ``"aten::selu"``.
            symbolic_fn: Symbolic function to install for ``self.opset_version``.
        """
        torch.onnx.register_custom_op_symbolic(
            symbolic_name=symbolic_name,
            symbolic_fn=symbolic_fn,
            opset_version=self.opset_version,
        )
        # Scheduled only after a successful registration, so there is always
        # something to unregister when the cleanup runs.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            symbolic_name,
            self.opset_version,
        )

    def test_selu_from_onnxscript_example(self):
        """Export ``torch.nn.SELU`` with ``aten::selu`` mapped to an onnxscript function."""
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model = torch.nn.SELU()

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

        @onnxscript.script(custom_opset)
        def Selu(
            X,
        ):
            # default value is not supported by onnxscript
            alpha = 1.67326  # auto wrapped as Constants
            gamma = 1.0507
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # The emitted node reuses the input's type as its output type.
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_symbolic_with_cleanup("aten::selu", custom_selu)
        self.run_test(model, x)

    def test_layer_norm(self):
        """Export stacked ``LayerNorm`` modules with ``aten::layer_norm`` mapped to onnxscript."""
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        class N(torch.nn.Module):
            """Thin wrapper around ``torch.nn.Dropout``."""

            def __init__(self, prob):
                super().__init__()
                self.dropout = torch.nn.Dropout(prob)

            def forward(self, x):
                return self.dropout(x)

        class M(torch.nn.Module):
            """Applies two CELUs, a chain of LayerNorms, and a nested dropout module."""

            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                # Each LayerNorm gets a different eps (0, 1, 2, ...), so every
                # emitted custom node carries its own ``eps`` attribute value.
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
                )
                self.celu1 = torch.nn.CELU(1.0)
                self.celu2 = torch.nn.CELU(2.0)
                self.dropout = N(0.5)

            def forward(self, x, y, z):
                res1 = self.celu1(x)
                res2 = self.celu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.dropout(z)

        model = M(3)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1)

        @onnxscript.script(custom_opset)
        def layer_norm(
            X,
            axes: Sequence[int],
            weight: FLOAT[...],
            bias: FLOAT[...],
            eps: float,
        ):
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            normalizedw = op.CastLike(
                normalized, weight
            )  # Type issue if missing this Op
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias

        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g, input, normalized_shape, weight, bias, eps, cudnn_enable
        ):
            # comprehension is not supported by onnxscript, so the trailing
            # axes (-len(normalized_shape), ..., -1) are computed here and
            # passed to the script function as an attribute.
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())

        self._register_symbolic_with_cleanup("aten::layer_norm", custom_layer_norm)

        self.run_test(model, (x, y, z))
 | 
						|
 | 
						|
 | 
						|
# Script entry point: run this file's tests via torch.testing's common_utils runner.
if __name__ == "__main__":
 | 
						|
    common_utils.run_tests()
 |