mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Wrap test decorators with functools.wraps (#78254)
- Decorates test wrappers with `functools.wraps` to preserve the test method names (previously the names become "wrapper_xxx", which prevents the parameterized tests from getting the correct names.) - Allows skip decorators to accept kwargs so multiple decorators can be used together Pull Request resolved: https://github.com/pytorch/pytorch/pull/78254 Approved by: https://github.com/BowenBao
This commit is contained in:
committed by
PyTorch MergeBot
parent
716f76716a
commit
5e03dfd36d
@ -50,12 +50,13 @@ skipIfNoBFloat16Cuda = _skipper(
|
||||
# smaller than the currently tested opset_version
|
||||
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
|
||||
def skip_dec(func):
|
||||
def wrapper(self):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.opset_version < min_opset_version:
|
||||
raise unittest.SkipTest(
|
||||
f"Unsupported opset_version: {self.opset_version} < {min_opset_version}"
|
||||
)
|
||||
return func(self)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -65,12 +66,13 @@ def skipIfUnsupportedMinOpsetVersion(min_opset_version):
|
||||
# skips tests for all versions above max_opset_version.
|
||||
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
|
||||
def skip_dec(func):
|
||||
def wrapper(self):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.opset_version > max_opset_version:
|
||||
raise unittest.SkipTest(
|
||||
f"Unsupported opset_version: {self.opset_version} > {max_opset_version}"
|
||||
)
|
||||
return func(self)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -80,12 +82,13 @@ def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
|
||||
# skips tests for all opset versions.
|
||||
def skipForAllOpsetVersions():
|
||||
def skip_dec(func):
|
||||
def wrapper(self):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.opset_version:
|
||||
raise unittest.SkipTest(
|
||||
"Skip verify test for unsupported opset_version"
|
||||
)
|
||||
return func(self)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -95,9 +98,10 @@ def skipForAllOpsetVersions():
|
||||
# skips tests for scripting.
|
||||
def skipScriptTest(min_opset_version=float("inf")):
|
||||
def script_dec(func):
|
||||
def wrapper(self):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self.is_script_test_enabled = self.opset_version >= min_opset_version
|
||||
return func(self)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -109,12 +113,13 @@ def skipScriptTest(min_opset_version=float("inf")):
|
||||
# (for example, an op was modified but the change is not supported in caffe2)
|
||||
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
|
||||
def skip_dec(func):
|
||||
def wrapper(self):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.opset_version in unsupported_opset_versions:
|
||||
raise unittest.SkipTest(
|
||||
"Skip verify test for unsupported opset_version"
|
||||
)
|
||||
return func(self)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
Reference in New Issue
Block a user