[ONNX] Wrap test decorators with functools.wraps (#78254)

- Decorates test wrappers with `functools.wraps` to preserve the test method names (previously the names become "wrapper_xxx", which prevents the parameterized tests from getting the correct names.)
- Allows skip decorators to accept kwargs so multiple decorators can be used together
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78254
Approved by: https://github.com/BowenBao
This commit is contained in:
Justin Chu
2022-05-25 22:17:25 +00:00
committed by PyTorch MergeBot
parent 716f76716a
commit 5e03dfd36d

View File

@ -50,12 +50,13 @@ skipIfNoBFloat16Cuda = _skipper(
# smaller than the currently tested opset_version
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
def skip_dec(func):
def wrapper(self):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if self.opset_version < min_opset_version:
raise unittest.SkipTest(
f"Unsupported opset_version: {self.opset_version} < {min_opset_version}"
)
return func(self)
return func(self, *args, **kwargs)
return wrapper
@ -65,12 +66,13 @@ def skipIfUnsupportedMinOpsetVersion(min_opset_version):
# skips tests for all versions above max_opset_version.
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
def skip_dec(func):
def wrapper(self):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if self.opset_version > max_opset_version:
raise unittest.SkipTest(
f"Unsupported opset_version: {self.opset_version} > {max_opset_version}"
)
return func(self)
return func(self, *args, **kwargs)
return wrapper
@ -80,12 +82,13 @@ def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
# skips tests for all opset versions.
def skipForAllOpsetVersions():
def skip_dec(func):
def wrapper(self):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if self.opset_version:
raise unittest.SkipTest(
"Skip verify test for unsupported opset_version"
)
return func(self)
return func(self, *args, **kwargs)
return wrapper
@ -95,9 +98,10 @@ def skipForAllOpsetVersions():
# skips tests for scripting.
def skipScriptTest(min_opset_version=float("inf")):
def script_dec(func):
def wrapper(self):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
self.is_script_test_enabled = self.opset_version >= min_opset_version
return func(self)
return func(self, *args, **kwargs)
return wrapper
@ -109,12 +113,13 @@ def skipScriptTest(min_opset_version=float("inf")):
# (for example, an op was modified but the change is not supported in caffe2)
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
def skip_dec(func):
def wrapper(self):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if self.opset_version in unsupported_opset_versions:
raise unittest.SkipTest(
"Skip verify test for unsupported opset_version"
)
return func(self)
return func(self, *args, **kwargs)
return wrapper