Files
pytorch/test/onnx/test_pytorch_common.py
BowenBao 679fc90cdb [ONNX] Support optional type (#68793) (#73284)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73284

Some important ops won't support optional type until opset 16,
so we can't fully test things end-to-end, but I believe this should
be all that's needed. Once ONNX Runtime supports opset 16,
we can do more testing and fix any remaining bugs.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D34625646

Pulled By: malfet

fbshipit-source-id: 537fcbc1e9d87686cc61f5bd66a997e99cec287b

Co-authored-by: BowenBao <bowbao@microsoft.com>
Co-authored-by: neginraoof <neginmr@utexas.edu>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
(cherry picked from commit 822e79f31ae54d73407f34f166b654f4ba115ea5)
2022-05-04 20:24:30 +00:00

102 lines
3.3 KiB
Python

# Owner(s): ["module: onnx"]
import functools
import os
import unittest
import sys
import torch
import torch.autograd.function as function
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(-1, pytorch_test_dir)
from torch.testing._internal.common_utils import * # noqa: F401,F403
# Module-level side effect: every process that imports this file gets
# "torch.FloatTensor" (CPU float32) as the default tensor type.
torch.set_default_tensor_type("torch.FloatTensor")
# Shared sizes for test inputs; presumably consumed by the ONNX test modules
# that import this helper file — confirm against callers.
BATCH_SIZE = 2
# Sizes for RNN test inputs (batch, sequence length, input and hidden size).
RNN_BATCH_SIZE = 7
RNN_SEQUENCE_LENGTH = 11
RNN_INPUT_SIZE = 5
RNN_HIDDEN_SIZE = 3
def _skipper(condition, reason):
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
if condition():
raise unittest.SkipTest(reason)
return f(*args, **kwargs)
return wrapper
return decorator
# Ready-made skip decorators built on ``_skipper``; each condition is evaluated
# when the decorated test runs, not when the module is imported.
skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(),
                        "CUDA is not available")
# Skips whenever the TRAVIS environment variable is set to a non-empty value.
skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"),
                        "Skip In Travis")
# Check CUDA availability first: in some torch versions
# ``torch.cuda.is_bf16_supported()`` queries the current device unconditionally
# and raises (instead of returning False) on machines without CUDA, which would
# make the test error out rather than be skipped.
skipIfNoBFloat16Cuda = _skipper(
    lambda: not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
    "BFloat16 CUDA is not available",
)
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    """Skip the decorated test when ``self.opset_version < min_opset_version``.

    Use this when exporting an op is only supported starting at a specific
    opset version, so the test is not run for smaller opset versions.

    The decorated function is called with the test instance as its first
    positional argument; any further arguments are forwarded unchanged, and
    the function's name/docstring are preserved via ``functools.wraps``.

    Raises:
        unittest.SkipTest: when the instance's opset version is too small.
    """
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version < min_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {self.opset_version} < {min_opset_version}"
                )
            return func(self, *args, **kwargs)
        return wrapper
    return skip_dec
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    """Skip the decorated test when ``self.opset_version > max_opset_version``.

    The decorated function is called with the test instance as its first
    positional argument; any further arguments are forwarded unchanged, and
    the function's name/docstring are preserved via ``functools.wraps``.

    Raises:
        unittest.SkipTest: when the instance's opset version is too large.
    """
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version > max_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {self.opset_version} > {max_opset_version}"
                )
            return func(self, *args, **kwargs)
        return wrapper
    return skip_dec
def skipForAllOpsetVersions():
    """Skip the decorated test whenever ``self.opset_version`` is set.

    The check is on truthiness, so a falsy ``opset_version`` (e.g. ``None``
    or ``0``) lets the test run.

    The decorated function is called with the test instance as its first
    positional argument; any further arguments are forwarded unchanged, and
    the function's name/docstring are preserved via ``functools.wraps``.

    Raises:
        unittest.SkipTest: when ``self.opset_version`` is truthy.
    """
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version:
                raise unittest.SkipTest("Skip verify test for unsupported opset_version")
            return func(self, *args, **kwargs)
        return wrapper
    return skip_dec
def skipScriptTest(min_opset_version=float("inf")):
    """Flag the decorated test's scripting variant as disabled below an opset version.

    Despite the name, this never raises ``unittest.SkipTest``: the wrapper
    sets ``self.is_script_test_enabled`` to
    ``self.opset_version >= min_opset_version`` and then always runs the test.
    With the default ``min_opset_version`` of infinity the flag is always
    ``False``. Presumably the test harness reads this flag to decide whether to
    also exercise scripting — confirm against the callers.

    The decorated function is called with the test instance as its first
    positional argument; any further arguments are forwarded unchanged, and
    the function's name/docstring are preserved via ``functools.wraps``.
    """
    def script_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Set on the instance before running so the test body can observe it.
            self.is_script_test_enabled = self.opset_version >= min_opset_version
            return func(self, *args, **kwargs)
        return wrapper
    return script_dec
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    """Skip the decorated test for each opset version in ``unsupported_opset_versions``.

    Use this when a test cannot run for specific opset versions, for example
    when an op was modified but the change is not supported in Caffe2.

    The decorated function is called with the test instance as its first
    positional argument; any further arguments are forwarded unchanged, and
    the function's name/docstring are preserved via ``functools.wraps``.

    Raises:
        unittest.SkipTest: when ``self.opset_version`` is in the given collection.
    """
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version in unsupported_opset_versions:
                raise unittest.SkipTest("Skip verify test for unsupported opset_version")
            return func(self, *args, **kwargs)
        return wrapper
    return skip_dec
def flatten(x):
    """Collect the tensors nested inside ``x`` into a flat tuple.

    Traversal is delegated to torch's private
    ``torch.autograd.function._iter_filter`` helper, keeping only objects that
    are ``torch.Tensor`` instances.

    NOTE(review): because this is a private API, the accepted container types
    and the error raised for unsupported objects may differ between torch
    versions. Confirm against the pinned version.
    """
    keep_tensors = function._iter_filter(lambda obj: isinstance(obj, torch.Tensor))
    return tuple(keep_tensors(x))