pytorch/test/onnx/test_pytorch_onnx_no_runtime.py
Justin Chu 524b78d4f6 [ONNX] Refactor torchscript based exporter (#161323)
Refactor the TorchScript-based exporter logic into a single (private) location for better code management. The original public module and method APIs are preserved.

- Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly
- Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused
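
For illustration, public entry points keep working while the internals now live under a private package (a sketch, not code from this PR; `model` and `args` are placeholders, and the import paths are the ones used by the updated tests):

```python
# Public API: unchanged
import torch
torch.onnx.export(model, args, "model.onnx")

# TorchScript-exporter internals: now under a private package
from torch.onnx._internal.torchscript_exporter import registration
```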

@albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and
`torch/autograd/_functions/utils.py`? Thanks!

## BC Breaking
- **Deprecated members in `torch.onnx.verification` are removed**

Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323
Approved by: https://github.com/titaiwangms, https://github.com/angelayi
2025-09-02 16:10:30 +00:00

# Owner(s): ["module: onnx"]
"""Tests for onnx export that don't run the exported model."""
from __future__ import annotations
import contextlib
import io
import itertools
import unittest
import unittest.mock
import warnings
from typing import Callable, Optional, TYPE_CHECKING, Union
import numpy as np
import onnx
import onnx.numpy_helper
import pytorch_test_common
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.onnx import symbolic_helper, utils
from torch.onnx._internal.torchscript_exporter import registration
from torch.testing._internal import common_quantization, common_utils, jit_utils
if TYPE_CHECKING:
from collections.abc import Iterable
def export_to_onnx(
model: Union[torch.nn.Module, torch.jit.ScriptFunction],
    input: Union[torch.Tensor, tuple[torch.Tensor, ...]],
custom_ops: Optional[
Iterable[Union[contextlib.AbstractContextManager, contextlib.ContextDecorator]]
] = None,
mocks: Optional[Iterable] = None,
operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,
opset_version: int = 17,
**torch_onnx_export_kwargs,
) -> onnx.ModelProto:
"""Exports `model(input)` to ONNX and returns it.
    Custom operators and/or unittest patches can be used to help reproduce specific behaviors.
Args:
model: model to export
        input: model input with the same format as `torch.onnx.export(..., args, ...)`
custom_ops: list of custom operators to use during export
mocks: list of mocks to use during export
operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)`
opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)`
torch_onnx_export_kwargs: extra torch.onnx.export kwargs arguments
Returns:
A valid ONNX model (`onnx.ModelProto`)
"""
custom_ops = custom_ops or []
mocks = mocks or []
with contextlib.ExitStack() as stack:
for ctx in itertools.chain(custom_ops, mocks):
stack.enter_context(ctx)
f = io.BytesIO()
torch.onnx.export(
model,
input,
f,
operator_export_type=operator_export_type,
opset_version=opset_version,
**torch_onnx_export_kwargs,
)
# Validate ONNX graph before returning it
onnx_model = onnx.load_from_string(f.getvalue())
onnx.checker.check_model(onnx_model)
return onnx_model
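# Example usage of the helper above (an illustrative sketch only; `MyModel` is a
# hypothetical module, not defined in this file):
#
#   model_proto = export_to_onnx(MyModel(), torch.randn(2, 3), opset_version=17)
#   assert isinstance(model_proto, onnx.ModelProto)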
@common_utils.instantiate_parametrized_tests
class TestONNXExport(pytorch_test_common.ExportTestCase):
def test_fuse_addmm(self):
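        # torch.mm followed by an elementwise add is expected to fuse into a
        # single node (e.g. Gemm) during export; this smoke test only checks
        # that export succeeds.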
class AddmmModel(torch.nn.Module):
def forward(self, x):
return torch.mm(x, x) + x
x = torch.ones(3, 3)
f = io.BytesIO()
torch.onnx.export(AddmmModel(), x, f)
def test_onnx_transpose_incomplete_tensor_type(self):
        # Smoke test to get us into the state where we are attempting to export
        # a transpose op whose input is a TensorType without size information.
        # This previously failed because the exporter took the input's sizes
        # and used their length as the number of dimensions in the permutation.
class Foo(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return x.contiguous().transpose(0, 1).sum()
class TraceMe(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Foo()
def forward(self, x):
return self.foo(x)
tm = TraceMe()
tm = torch.jit.trace(tm, torch.rand(3, 4))
f = io.BytesIO()
torch.onnx.export(tm, (torch.rand(3, 4),), f)
def test_export_tensoroption_to(self):
def foo(x):
return x[0].detach().clone().cpu() + x
traced = torch.jit.trace(foo, (torch.rand([2])))
f = io.BytesIO()
torch.onnx.export(traced, (torch.rand([2]),), f)
def test_onnx_export_script_module(self):
class ModuleToExport(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
y = x - x # noqa: F841
return x + x
mte = ModuleToExport()
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
@common_utils.suppress_warnings
def test_onnx_export_func_with_warnings(self):
@torch.jit.script
def func_with_warning(inp):
return torch.nn.functional.sigmoid(inp) # triggers a deprecation warning
class WarningTest(torch.nn.Module):
def forward(self, x):
return func_with_warning(x)
# no exception
f = io.BytesIO()
torch.onnx.export(WarningTest(), torch.randn(42), f)
def test_onnx_export_script_python_fail(self):
class PythonModule(torch.jit.ScriptModule):
@torch.jit.ignore
def forward(self, x):
return torch.neg(x)
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.mod = PythonModule()
@torch.jit.script_method
def forward(self, x):
y = self.mod(x)
return y + y
mte = ModuleToExport()
f = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
def test_onnx_export_script_inline_trace(self):
class ModuleToInline(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))
@torch.jit.script_method
def forward(self, x):
y = self.mod(x)
return y + y
mte = ModuleToExport()
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
def test_onnx_export_script_inline_script(self):
class ModuleToInline(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return torch.neg(x)
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.mod = ModuleToInline()
@torch.jit.script_method
def forward(self, x):
y = self.mod(x)
return y + y
mte = ModuleToExport()
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
def test_onnx_export_script_module_loop(self):
class ModuleToExport(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
                # Test that we support end-to-end ONNX export of loops,
                # including nested loops, with and without a loop index.
for _ in range(5):
for i in range(3):
x = x + i
return x
mte = ModuleToExport()
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
@common_utils.suppress_warnings
def test_onnx_export_script_truediv(self):
class ModuleToExport(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
z = x.size(0) / 2
return x + z
mte = ModuleToExport()
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3, dtype=torch.float),), f)
def test_onnx_export_script_non_alpha_add_sub(self):
class ModuleToExport(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
bs = x.size(0) + 1
return bs - 1
mte = ModuleToExport()
f = io.BytesIO()
torch.onnx.export(mte, (torch.rand(3, 4),), f)
def test_onnx_export_script_module_if(self):
class ModuleToExport(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
if bool(torch.sum(x) > 0):
x = torch.neg(x)
return x
mte = ModuleToExport()
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
def test_onnx_export_script_inline_params(self):
class ModuleToInline(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.m = torch.nn.Parameter(torch.ones(3, 3))
self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
@torch.jit.script_method
def forward(self, x):
return torch.mm(x, self.m)
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.mod = ModuleToInline()
self.param = torch.nn.Parameter(torch.ones(3, 4))
@torch.jit.script_method
def forward(self, x):
y = self.mod(x)
return torch.mm(y, self.param)
mte = ModuleToExport()
result = mte(torch.zeros(2, 3))
reference = torch.mm(
torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4)
)
self.assertEqual(result, reference)
f = io.BytesIO()
torch.onnx.export(mte, (torch.ones(2, 3),), f)
def test_onnx_export_speculate(self):
class Foo(torch.jit.ScriptModule):
def __init__(self, m):
super().__init__()
self.m = m
@torch.jit.script_method
def forward(self, x):
x += x
                # Because we are testing whether the `if` statement is emitted
                # correctly, we cannot use `True` as the condition: constant
                # propagation would remove the `if` statements.
c = torch.sum(x) > 4
if bool(c):
if bool(c):
y = self.m(x)
else:
y = self.m(x)
else:
y = self.m(x)
return y
linear = torch.jit.trace(
torch.nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float)
)
@torch.jit.script
def transpose(x):
return x.t()
f1 = Foo(transpose)
f2 = Foo(linear)
f = io.BytesIO()
torch.onnx.export(f1, (torch.ones(1, 10, dtype=torch.float),), f)
f = io.BytesIO()
torch.onnx.export(f2, (torch.ones(1, 10, dtype=torch.float),), f)
def test_onnx_export_shape_reshape(self):
class Foo(torch.nn.Module):
def forward(self, x):
import torch.onnx.operators
x = x.repeat(5, 1, 1)
shape = torch.onnx.operators.shape_as_tensor(x)
reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
return reshaped
foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
f = io.BytesIO()
        torch.onnx.export(foo, (torch.zeros(1, 2, 3),), f)
def test_export_dynamic_slice(self):
class DynamicSliceExportMod(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
retval = x[0]
for i in range(x.size(1)):
retval += torch.sum(x[0:i], dim=0)
return retval
input = torch.rand(3, 4, 5)
f = io.BytesIO()
torch.onnx.export(DynamicSliceExportMod(), (input,), f, opset_version=10)
def test_export_dict(self):
class DictModule(torch.nn.Module):
def forward(self, x_in: torch.Tensor) -> dict[str, torch.Tensor]:
return {"test_key_out": x_in}
x_in = torch.tensor(1)
mod = DictModule()
mod.train(False)
f = io.BytesIO()
torch.onnx.export(mod, (x_in,), f)
with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported"):
f = io.BytesIO()
torch.onnx.export(torch.jit.script(mod), (x_in,), f)
def test_source_range_propagation(self):
class ExpandingModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# Will be expanded during ONNX export
self.ln = torch.nn.LayerNorm([1])
def forward(self, input):
return self.ln(input)
mod = ExpandingModule()
graph, _, _ = utils._model_to_graph(
mod,
(torch.zeros(1),),
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
)
# Ensure that every node in the graph has a valid source range
for node in graph.nodes():
self.assertTrue(node.sourceRange())
def test_clip_aten_fallback_due_exception(self):
def bad_clamp(g, self, min, max):
return symbolic_helper._onnx_unsupported("Bad boy!")
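        # `_onnx_unsupported` raises, so with ONNX_ATEN_FALLBACK the exporter
        # should fall back to emitting an ATen op (asserted below).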
class MyClip(torch.nn.Module):
def forward(self, x):
return torch.clamp(x, min=-0.5, max=0.5)
onnx_model = export_to_onnx(
MyClip(),
torch.randn(3, 4, requires_grad=True),
custom_ops=[common_utils.custom_op("aten::clamp", bad_clamp, 17)],
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
self.assertAtenOp(onnx_model, "clamp", "Tensor")
def test_clip_aten_fallback_explicit_request(self):
class MyClip(torch.nn.Module):
def forward(self, x):
return torch.clamp(x, min=-0.5, max=0.5)
        # A copy of the mocked method must be saved to avoid exceeding the
        # maximum recursion depth when calling the original instance method.
original_get_function_group = registration.registry.get_function_group
def break_is_registered_op_api(name):
fake_missing_symbolics = {"aten::clamp"}
if name in fake_missing_symbolics:
return None
return original_get_function_group(name)
# Force missing symbolic for well-known op using a mock
onnx_model = export_to_onnx(
MyClip(),
torch.randn(3, 4, requires_grad=True),
mocks=[
unittest.mock.patch(
"torch.onnx._internal.torchscript_exporter.registration.registry.get_function_group",
side_effect=break_is_registered_op_api,
)
],
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
self.assertAtenOp(onnx_model, "clamp", "Tensor")
def _helper_test_to_(self, cast_fn: Callable[[torch.Tensor], torch.Tensor]):
"""Helper to test aten::to(device) variants.
        `cast_fn` is compiled with `torch.jit.script`. It wraps `aten::to`
        during export to prevent the device from being hard-coded.
Needed by detectron2 after https://github.com/facebookresearch/detectron2/pull/4132/
"""
cast_fn = torch.jit.script(cast_fn)
onnx_model = export_to_onnx(cast_fn, torch.zeros([1, 3, 32, 32]))
for n in onnx_model.graph.node:
self.assertNotEqual(n.op_type, "To")
self.assertNotEqual(n.op_type, "Cast")
def test_to__cpu_string(self):
def cast_cpu_string(src: torch.Tensor) -> torch.Tensor:
return src.to("cpu")
self._helper_test_to_(cast_cpu_string)
def test_to__device_cpu_string(self):
def cast_device_cpu_string(src: torch.Tensor) -> torch.Tensor:
return src.to(device="cpu")
self._helper_test_to_(cast_device_cpu_string)
def test_script_custom_class_error(self):
class BoxCoder:
def __init__(self, bbox_xform_clip: float) -> None:
self.bbox_xform_clip = bbox_xform_clip
def decode(self, rel_codes: Tensor, boxes: list[Tensor]) -> Tensor:
boxes = torch.cat(boxes, dim=0)
pred_ctr_x = (
torch.clamp(rel_codes[:, 0::4], max=self.bbox_xform_clip)
* boxes[:, 2]
)
return pred_ctr_x
class MyModule(torch.nn.Module):
__annotations__ = {
"box_coder": BoxCoder,
}
def __init__(self) -> None:
super().__init__()
self.box_coder = BoxCoder(1.4)
def forward(self, box_regression: Tensor, proposals: list[Tensor]):
return self.box_coder.decode(box_regression, proposals)
model = torch.jit.script(MyModule())
box_regression = torch.randn([4, 4])
proposal = [torch.randn(2, 4), torch.randn(2, 4)]
with self.assertRaises(RuntimeError):
f = io.BytesIO()
torch.onnx.export(
model,
(box_regression, proposal),
f,
)
def test_initializer_sequence(self):
class MyModule(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
test_model = MyModule(3, 4, 10)
state_dict_list = [k for (k, v) in test_model.state_dict().items()]
named_params_list = [k for (k, v) in test_model.named_parameters()]
x = torch.randn(32, 3)
f = io.BytesIO()
torch.onnx.export(test_model, (x,), f, do_constant_folding=False)
loaded_model = onnx.load_from_string(f.getvalue())
actual_list = [p.name for p in loaded_model.graph.initializer]
assert actual_list == state_dict_list, (
"Initializers' sequence is not as same as state_dict(). Expected: ("
+ ", ".join(state_dict_list)
+ "). Actual:("
+ ", ".join(actual_list)
+ ")."
)
assert actual_list == named_params_list, (
"Initializers' sequence is not as same as named_parameters(). Expected: ("
+ ", ".join(named_params_list)
+ "). Actual:("
+ ", ".join(actual_list)
+ ")."
)
def test_initializer_sequence_script_model(self):
def list_is_expected(short_list, long_list) -> bool:
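            # Each expected name must appear as a substring of the actual
            # initializer name at the same position.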
if len(short_list) > len(long_list):
return False
for i in range(len(short_list)):
if short_list[i] not in long_list[i]:
return False
return True
def loop(x, y):
for i in range(int(y)):
x = x + i
return x
class MyModule(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, x, y):
x = loop(x, y)
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
test_model = torch.jit.script(MyModule(3, 4, 10))
state_dict_list = [k for (k, v) in test_model.state_dict().items()]
named_params_list = [k for (k, v) in test_model.named_parameters()]
x = torch.ones(2, 3, dtype=torch.float)
y = torch.tensor(5, dtype=torch.long)
f = io.BytesIO()
torch.onnx.export(test_model, (x, y), f, do_constant_folding=False)
loaded_model = onnx.load_from_string(f.getvalue())
actual_list = [p.name for p in loaded_model.graph.initializer]
assert list_is_expected(state_dict_list, actual_list), (
"ScriptModel - Initializers' sequence is not as same as state_dict(). Expected: ("
+ ", ".join(state_dict_list)
+ "). Actual:("
+ ", ".join(actual_list)
+ ")."
)
assert list_is_expected(named_params_list, actual_list), (
"ScriptModel - Initializers' sequence is not as same as named_parameters(). Expected: ("
+ ", ".join(named_params_list)
+ "). Actual:("
+ ", ".join(actual_list)
+ ")."
)
def test_shape_value_map(self):
class RSoftMax(torch.nn.Module):
def __init__(self, radix, cardinality):
super().__init__()
self.radix = radix
self.cardinality = cardinality
def forward(self, x):
batch = x.size(0)
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
return x
radix = 2
cardinality = 1
x = torch.randn(10, 1, 128, 1)
f = io.BytesIO()
torch.onnx.export(
RSoftMax(radix, cardinality),
(x,),
f,
input_names=["x"],
dynamic_axes={"x": [0]},
)
loaded_model = onnx.load_from_string(f.getvalue())
self.assertEqual(
loaded_model.graph.output[0].type.tensor_type.shape.dim[1].dim_value, 128
)
def test_onnx_proto_checker(self):
class Model(torch.nn.Module):
def forward(self, x):
return 2 * x
x = torch.randn(1, 2, 3, requires_grad=True)
f = io.BytesIO()
torch.onnx.export(Model(), (x,), f)
        f.seek(0)
        model = onnx.load(f)
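        # An ir_version of 0 is invalid, so the proto checker should reject the model.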
model.ir_version = 0
def check_proto():
torch._C._check_onnx_proto(model.SerializeToString())
self.assertRaises(RuntimeError, check_proto)
def test_maintain_dynamic_shapes_of_unreliable_nodes(self):
def symbolic_pythonop(g, *args, **kwargs):
return g.op("com.microsoft::PythonOp")
torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1)
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1)
        # necessary parameters for transformer embeddings
hidden_size = 48
max_position_embeddings = 32
batch_size = 2
        # Known issue: an autograd.Function can make downstream nodes
        # unreliable while keeping static shapes. The issue was first
        # discovered when using Apex FusedLayerNorm in Transformers.
class CustomLayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, embedding):
layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12)
return layer_norm(embedding)
class EmbeddingModule(torch.nn.Module):
def forward(
self,
embeddings=None,
):
embedding_output = CustomLayerNorm.apply(embeddings)
query = embedding_output.transpose(0, 1)
target_len, batch_size, embedding_dim = query.size()
                # Reshape consumes batch_size; if batch_size is static,
                # the shape becomes a Constant node in the graph.
query = query.reshape(target_len, batch_size, embedding_dim)
return query
embeddings = torch.randn(batch_size, max_position_embeddings, hidden_size)
f = io.BytesIO()
torch.onnx.export(
EmbeddingModule().eval(),
(embeddings,),
f,
input_names=["embeddings"],
dynamic_axes={
"embeddings": {
0: "batch_size",
1: "max_position_embeddings",
2: "hidden_size",
}
},
custom_opsets={"com.microsoft": 1},
)
model = onnx.load(io.BytesIO(f.getvalue()))
        # If there is a Constant node whose value is the 3-D shape
        # [max_position_embeddings, batch_size, hidden_size], the shape has
        # become static. With a dynamic batch size, no such constant should exist.
const_node = [n for n in model.graph.node if n.op_type == "Constant"]
self.assertNotEqual(len(const_node), 0)
for node in const_node:
for a in node.attribute:
if a.name == "value":
shape = onnx.numpy_helper.to_array(a.t)
self.assertNotEqual(
shape.tolist(),
[max_position_embeddings, batch_size, hidden_size],
)
def test_is_fp_for_C_TypeList(self):
class M(torch.nn.Module):
def forward(self, x):
x = x.squeeze(1)
w = x.shape[2]
pos = x.view(2, -1).argmax(1)
x_int = pos % w
y_int = (pos - x_int) // w
return y_int, x_int
model = torch.jit.script(M())
inputs = torch.randn(2, 4, 6)
f = io.BytesIO()
torch.onnx.export(
model, inputs, f, dynamic_axes={"x": [0, 1]}, input_names=["x"]
)
def test_dropout_script(self):
eg = torch.zeros(1, 2, 3, requires_grad=True)
@jit_utils._trace(eg)
def foo(x):
x = torch.neg(x)
return F.dropout(x)
class MyDrop(torch.nn.Module):
def forward(self, x):
return foo(x)
f = io.BytesIO()
with warnings.catch_warnings(record=True):
torch.onnx.export(MyDrop(), (eg,), f)
def test_pack_padded_pad_packed_trace(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
T, B, C = 3, 5, 7
class PadPackedWrapper(torch.nn.Module):
def forward(self, x, seq_lens):
x = pack_padded_sequence(x, seq_lens)
x, _ = pad_packed_sequence(x)
return x
x = np.ones((T, B, C))
seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
# set padding value so we can test equivalence
for b in range(B):
if seq_lens[b] < T:
x[seq_lens[b] :, b, :] = 0
seq_lens = torch.from_numpy(seq_lens)
x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)
m = PadPackedWrapper()
m_traced = torch.jit.trace(
m,
(
x,
seq_lens,
),
)
y = m(x, seq_lens)
loss = torch.sum(y)
loss.backward()
grad = x.grad.clone()
x.grad.zero_()
y_traced = m_traced(x, seq_lens)
loss_traced = torch.sum(y_traced)
loss_traced.backward()
grad_traced = x.grad.clone()
self.assertEqual(y_traced, x)
self.assertEqual(y_traced, y)
self.assertEqual(grad, grad_traced)
f = io.BytesIO()
torch.onnx.export(m, (x, seq_lens), f)
# Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
@common_utils.suppress_warnings
def test_rnn_trace_override(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
num_layers = 3
T, B, C = 11, 5, 7
class RNNTraceWrapper(torch.nn.Module):
def __init__(self, cell_type):
super().__init__()
if cell_type == "RNN":
self.rnn = torch.nn.RNN(
input_size=C, hidden_size=C, num_layers=num_layers
)
elif cell_type == "LSTM":
self.rnn = torch.nn.LSTM(
input_size=C, hidden_size=C, num_layers=num_layers
)
elif cell_type == "GRU":
self.rnn = torch.nn.GRU(
input_size=C, hidden_size=C, num_layers=num_layers
)
def forward(self, x, seq_lens):
x = pack_padded_sequence(x, seq_lens)
x, _ = self.rnn(x)
x, _ = pad_packed_sequence(x)
return x
for cell_type in ["RNN", "LSTM", "GRU"]:
x = torch.ones(T, B, C, requires_grad=True)
seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
m = RNNTraceWrapper(cell_type)
m_traced = torch.jit.trace(
m,
(
x,
seq_lens,
),
)
y = m(x, seq_lens)
loss = torch.sum(y)
loss.backward()
grad = x.grad.clone()
x.grad.zero_()
y_traced = m_traced(x, seq_lens)
loss_traced = torch.sum(y_traced)
loss_traced.backward()
grad_traced = x.grad.clone()
self.assertEqual(y_traced, y)
self.assertEqual(grad, grad_traced)
f = io.BytesIO()
torch.onnx.export(m, (x, seq_lens), f)
def test_pushpackingpastrnn_in_peephole_create_own_gather_input(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
num_layers = 3
T, B, C = 11, 5, 7
mask_start_point = 0
class LSTMTraceWrapper(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.rnn = torch.nn.LSTM(
input_size=C, hidden_size=C, num_layers=num_layers
)
def forward(self, x, seq_lens):
mask = torch.arange(mask_start_point, x.shape[1])
seq_lens = seq_lens[mask]
x = pack_padded_sequence(x, seq_lens)
                # Compute the max batch size and create initial hidden/cell states (hx, cx) for the LSTM.
max_batch_size = x.batch_sizes[0]
hx = torch.randn(num_layers, max_batch_size, C)
cx = torch.randn(num_layers, max_batch_size, C)
x, _ = self.rnn(x, (hx, cx))
x, _ = pad_packed_sequence(x)
return x
x = torch.ones(T, B, C)
        # seq_lens has length 5 to match B
seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
m = LSTMTraceWrapper()
f = io.BytesIO()
torch.onnx.export(
m,
(x, seq_lens),
f,
verbose=True,
input_names=["input", "seq_len"],
dynamic_axes={"input": {1: "B"}},
)
onnx_proto = onnx.load_model_from_string(f.getvalue())
        # The first input of onnx::Range should be a Constant node with value 0.
const_node = []
constant_input_name = None
for n in onnx_proto.graph.node:
if n.op_type == "Constant":
const_node.append(n)
elif n.op_type == "Range":
constant_input_name = n.input[0]
self.assertNotEqual(constant_input_name, None)
self.assertNotEqual(len(const_node), 0)
value = None
for n in const_node:
if n.output[0] == constant_input_name:
value = np.frombuffer(n.attribute[0].t.raw_data, dtype=np.int64)
self.assertEqual(value, 0)
def test_trace_fork_wait_inline_onnx(self):
def fork_body(x):
return torch.neg(x), torch.neg(x)
class MyMod(torch.nn.Module):
def forward(self, x):
fut = torch.jit._fork(fork_body, x)
val = torch.jit._wait(fut)
return val[1]
# smoke test for ONNX export
f = io.BytesIO()
torch.onnx.export(MyMod(), (torch.rand(3, 4),), f)
def test_trace_detach_onnx_erase(self):
class Mod(torch.nn.Module):
def forward(self, x, w):
return torch.matmul(x, w).detach()
f = io.BytesIO()
torch.onnx.export(Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)
def test_aten_fallback_must_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg
x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")
def test_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)
x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "fmod", "Tensor")
def test_onnx_aten_fallback_must_not_fallback(self):
        # With BUILD_CAFFE2=0, ATen fallback is used only when the op is not exportable.
class ONNXExportable(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.quant = torch.ao.quantization.QuantStub()
self.fc1 = torch.nn.Linear(12, 8)
self.fc2 = torch.nn.Linear(8, 4)
self.fc3 = torch.nn.Linear(4, 6)
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = x.view((-1, 12))
h = F.relu(self.fc1(x))
h = F.relu(self.fc2(h))
h = F.relu(self.fc3(h))
h = self.dequant(h)
return h
dummy_input = torch.randn(12)
f = io.BytesIO()
torch.onnx.export(
ONNXExportable(),
(dummy_input,),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
all_aten_nodes = [
p
for p in onnx_model.graph.node
if p.op_type == "ATen" and p.domain == "org.pytorch.aten"
]
self.assertEqual(len(all_aten_nodes), 0)
def test_cat_with_empty_tensor(self):
class NoopConcat(torch.nn.Module):
def forward(self, x):
return torch.cat((torch.Tensor([]), x))
x = torch.randn(4, 5, 6)
# TODO: Parametrize this test for opset_version
for opset_version in {9, 11}:
f = io.BytesIO()
torch.onnx.export(NoopConcat(), (x,), f, opset_version=opset_version)
loaded_model = onnx.load_from_string(f.getvalue())
self.assertEqual(
len(loaded_model.graph.output[0].type.tensor_type.shape.dim), 3
)
for idx, dim in enumerate(x.shape):
self.assertEqual(
loaded_model.graph.output[0]
.type.tensor_type.shape.dim[idx]
.dim_value,
dim,
)
def test_col2im(self):
        # This test can be moved to test/onnx/test_pytorch_onnx_onnxruntime.py once ORT implements Col2Im.
# Random batched RGB 32x32 image-shaped input tensor of batch size 64
original_image_inputs = torch.randn((64, 3, 32, 32))
output_size = tuple(original_image_inputs.shape[2:])
kernel_size = (1, 2)
dilation = 3
padding = 2
stride = 1
model_im2col = torch.nn.Unfold(
kernel_size, dilation=dilation, padding=padding, stride=stride
)
blocks = model_im2col(original_image_inputs)
model = torch.nn.Fold(
output_size=output_size,
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
stride=stride,
)
f = io.BytesIO()
torch.onnx.export(model, (blocks,), f, opset_version=18)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertEqual(onnx_model.graph.node[-1].op_type, "Col2Im")
self.assertEqual(onnx_model.graph.node[-1].domain, "")
self.assertEqual(len(onnx_model.graph.node[-1].input), 3)
self.assertEqual(onnx_model.graph.node[-1].attribute[0].name, "dilations")
self.assertEqual(onnx_model.graph.node[-1].attribute[1].name, "pads")
self.assertEqual(onnx_model.graph.node[-1].attribute[2].name, "strides")
@unittest.skipIf(
not torch.hub._check_module_exists("torch_scatter"),
"torch_scatter not installed.",
)
def test_random_namespace_custom_op_is_onnx_exportable(self):
from torch_scatter import scatter_max # type: ignore[import]
class MyModel(torch.nn.Module):
def forward(self, src: torch.Tensor, idx: torch.Tensor):
return scatter_max(src, idx)
m = MyModel().eval()
src = torch.ones([3, 10], dtype=torch.float32)
idx = torch.randint(0, 4, [3, 10], dtype=torch.long)
def sym_scatter_max(g, src, index, dim, out, dim_size):
return g.op(
"torch_scatter::scatter_max", src, index, dim_size_i=-1, outputs=2
)
torch.onnx.register_custom_op_symbolic(
"torch_scatter::scatter_max", sym_scatter_max, 1
)
f = io.BytesIO()
with torch.no_grad():
torch.onnx.export(
m,
(src, idx),
f,
opset_version=13,
custom_opsets={"torch_scatter": 1},
do_constant_folding=True,
)
@common_utils.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
def test_fp8_export(self, fp8_dtype: torch.dtype):
class Model(torch.nn.Module):
def forward(self, x):
return x.to(torch.float32)
x = torch.randn(2, 3).to(fp8_dtype)
f = io.BytesIO()
torch.onnx.export(Model(), x, f, opset_version=19)
onnx.checker.check_model(f.getvalue())
onnx_type = {
torch.float8_e4m3fn: 17,
torch.float8_e5m2: 19,
} # From https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3#L512-L521
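        # i.e. onnx.TensorProto.FLOAT8E4M3FN == 17 and
        # onnx.TensorProto.FLOAT8E5M2 == 19 (assuming an onnx release with fp8 support).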
loaded_model = onnx.load_from_string(f.getvalue())
self.assertEqual(
loaded_model.graph.input[0].type.tensor_type.elem_type, onnx_type[fp8_dtype]
)
class TestQuantizeEagerONNXExport(common_utils.TestCase):
def _test_lower_graph_impl(self, model, data):
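        # Eager-mode post-training quantization: attach a qconfig, insert
        # observers via prepare(), then convert() to a quantized module
        # before exporting.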
model.qconfig = torch.ao.quantization.default_qconfig
model = torch.ao.quantization.prepare(model)
model = torch.ao.quantization.convert(model)
_ = model(data)
input_names = ["x"]
def _export_to_onnx(model, input, input_names):
traced = torch.jit.trace(model, input)
buf = io.BytesIO()
torch.jit.save(traced, buf)
buf.seek(0)
model = torch.jit.load(buf)
f = io.BytesIO()
torch.onnx.export(
model,
input,
f,
input_names=input_names,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
opset_version=9,
)
_export_to_onnx(model, data, input_names)
@common_quantization.skipIfNoFBGEMM
@unittest.skip(
"onnx opset9 does not support quantize_per_tensor and caffe2 \
does not support conv3d"
)
def test_lower_graph_conv3d(self):
model = torch.ao.quantization.QuantWrapper(
torch.nn.Conv3d(3, 5, 2, bias=True)
).to(dtype=torch.float)
data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32)
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
self._test_lower_graph_impl(model, data)
@pytorch_test_common.skipIfNoCuda
def test_composed_layer_norm_small_eps_fp16_keep_double(self):
class Net(torch.nn.Module):
def __init__(self, C):
super().__init__()
self.layer_norm = torch.nn.LayerNorm(C, eps=1e-8)
def forward(self, x):
return self.layer_norm(x)
N, C = 8, 4
model = Net(C).cuda().half()
x = torch.randn(N, C).cuda().half()
f = io.BytesIO()
torch.onnx.export(model, (x,), f, opset_version=14)
onnx_model = onnx.load_from_string(f.getvalue())
const_node = [n for n in onnx_model.graph.node if n.op_type == "Constant"]
self.assertNotEqual(len(const_node), 0)
double_type_count = 0
for node in const_node:
for a in node.attribute:
                    # The eps constant should be stored as double (data_type 11 == TensorProto.DOUBLE)
if a.name == "value" and a.t.data_type == 11:
double_type_count += 1
self.assertNotEqual(double_type_count, 0)
@pytorch_test_common.skipIfNoCuda
def test_aten_device_with_index(self):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
model = torch.compile(model, backend="onnxrt")
model = model.eval()
device = "cuda:0"
model = model.to(device)
ids = tokenizer.batch_encode_plus(["This is a test"], return_tensors="pt").to(
device
)
with torch.no_grad():
_ = model(
input_ids=ids["input_ids"],
attention_mask=ids["attention_mask"],
decoder_input_ids=ids["input_ids"],
decoder_attention_mask=ids["attention_mask"],
)
def test_aten_linalg_vector_norm_with_reducel2(self):
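        # F.normalize lowers to aten::linalg_vector_norm, which should export
        # to ONNX ReduceL2 (asserted below).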
class Net(torch.nn.Module):
def forward(self, x):
x = F.normalize(x)
return x
f = io.BytesIO()
torch.onnx.export(Net(), (torch.randn(1, 2, 2),), f)
onnx_model = onnx.load_from_string(f.getvalue())
onnx_nodes = [n.op_type for n in onnx_model.graph.node]
self.assertIn("ReduceL2", onnx_nodes)
if __name__ == "__main__":
common_utils.run_tests()