mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Set dynamo=True and enable fallback. 1. Implement backward-compatible behavior so that BytesIO objects are accepted as `f`. 2. Update tests to explicitly set dynamo=False. #151693 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159646 Approved by: https://github.com/titaiwangms
		
			
				
	
	
		
			180 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			180 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: onnx"]
 | |
| 
 | |
| """Test the support on onnxscript in PyTorch-ONNX converter."""
 | |
| 
 | |
| import io
 | |
| 
 | |
| import onnx
 | |
| 
 | |
| import onnxscript
 | |
| from onnxscript.onnx_types import FLOAT
 | |
| 
 | |
| import torch
 | |
| from torch.onnx._internal.torchscript_exporter import jit_utils
 | |
| from torch.testing._internal import common_utils
 | |
| 
 | |
| 
 | |
class TestONNXScriptExport(common_utils.TestCase):
    """Export tests for custom ops implemented as onnxscript functions.

    Each test registers onnxscript-backed symbolic functions for ATen ops,
    exports with the TorchScript-based exporter (``dynamo=False``), and checks
    that the onnxscript functions end up in the exported ModelProto's
    ``functions`` list.
    """

    # The opset version is fixed at 15 because:
    # 1. ONNX local functions are only supported from opset 15 onward.
    # 2. onnxscript requires users to pick the opset of a local function explicitly.
    opset_version = 15

    def _register_symbolic(self, symbolic_name, symbolic_fn):
        """Register ``symbolic_fn`` for ``symbolic_name`` for the duration of one test.

        ``torch.onnx.register_custom_op_symbolic`` is not scoped to a test, so
        without cleanup a registration (both tests below register
        ``aten::selu``) would leak into every test that runs afterwards.
        """
        torch.onnx.register_custom_op_symbolic(
            symbolic_name=symbolic_name,
            symbolic_fn=symbolic_fn,
            opset_version=self.opset_version,
        )
        # Undo the registration on teardown, whether or not the test passes.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            symbolic_name,
            self.opset_version,
        )

    def test_onnxscript_registration_with_multiple_models(self):
        """Two models, each using a different onnxscript custom op, export independently."""
        from onnxscript.onnx_opset import opset15 as op

        # 1. Register Selu onnxscript function as custom Op
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

        @onnxscript.script(custom_opset)
        def Selu(X):
            # default value is not supported by onnxscript
            alpha = 1.67326  # auto wrapped as Constants
            gamma = 1.0507
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_symbolic("aten::selu", custom_selu)

        # 2. Register layer_norm onnxscript function as custom Op
        @onnxscript.script(custom_opset)
        def layer_norm(
            X, axes: list[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
        ):
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            normalizedw = op.CastLike(
                normalized, weight
            )  # Type issue if missing this Op
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias

        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g, input, normalized_shape, weight, bias, eps, cudnn_enable
        ):
            # Comprehensions are not supported inside onnxscript functions, so the
            # reduction axes (the trailing len(normalized_shape) dims, e.g. [-2, -1])
            # are computed here and passed to layer_norm via ``axes_i``.
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())

        self._register_symbolic("aten::layer_norm", custom_layer_norm)

        # 3. export two models
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model_selu = torch.nn.SELU()
        selu_onnx = io.BytesIO()
        torch.onnx.export(
            model_selu, x, selu_onnx, opset_version=self.opset_version, dynamo=False
        )

        N, C = 3, 4
        y = torch.randn(N, C)
        model_layer_norm = torch.nn.LayerNorm(C)
        layer_norm_onnx = io.BytesIO()
        torch.onnx.export(
            model_layer_norm,
            y,
            layer_norm_onnx,
            opset_version=self.opset_version,
            dynamo=False,
        )

        # 4. test on models: each proto carries exactly its own onnxscript function
        selu_proto = onnx.load(io.BytesIO(selu_onnx.getvalue()))
        layer_norm_proto = onnx.load(io.BytesIO(layer_norm_onnx.getvalue()))

        self.assertEqual(len(selu_proto.functions), 1)
        self.assertEqual(len(layer_norm_proto.functions), 1)
        self.assertEqual(selu_proto.functions[0].name, "Selu")
        self.assertEqual(layer_norm_proto.functions[0].name, "layer_norm")

    def test_loop_registration(self):
        """An onnxscript custom op used only inside control flow is still exported."""
        # Exercises the exporter's ``_find_onnxscript_op``, which recursively walks
        # every node that carries a subgraph in the model proto: here the only
        # ``aten::selu`` call sits inside the scripted for-loop / if-branch.
        class NestedLoopsModel(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.selu = torch.nn.SELU()

            @torch.jit.script_method
            def forward(self, x):
                y = x
                for i in range(x.size(3)):
                    if i == 0:
                        y = self.selu(x)
                    else:
                        y += i
                return y

        model = NestedLoopsModel()
        inputs = torch.zeros(1, 2, 3, 4)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=2)

        @onnxscript.script(custom_opset)
        def Selu(X):
            alpha = 1.6732632423543772848170429916717
            gamma = 1.0507009873554804934193349852946
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g, X):
            # The domain of the emitted op should match the onnxscript opset, and
            # setType is required so TorchScript shape/type inference knows the
            # output type of the custom op.
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_symbolic("aten::selu", custom_selu)

        saved_model = io.BytesIO()
        torch.onnx.export(
            torch.jit.script(model),
            inputs,
            f=saved_model,
            opset_version=self.opset_version,
            dynamo=False,
        )
        loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
        self.assertEqual(len(loop_selu_proto.functions), 1)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Running this module directly is deliberately refused: per the message
    # below, these tests are disabled unless enabled in discover_tests.py.
    message = (
        "This test is not currently used and should be "
        "enabled in discover_tests.py if required."
    )
    raise RuntimeError(message)
 |