[BE] Refactor repetitions into TorchVersion._cmp_wrapper` (#71344)

Summary:
First step towards https://github.com/pytorch/pytorch/issues/71280

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71344

Reviewed By: b0noI

Differential Revision: D33594463

Pulled By: malfet

fbshipit-source-id: 0295f0d9f0342f05a390b2bd4aa0a5958c76579b
This commit is contained in:
Nikita Shulga
2022-01-14 19:56:46 -08:00
committed by Facebook GitHub Bot
parent c43e0286a9
commit 3ed27a96ed

View File

@ -1,4 +1,4 @@
from typing import Iterable, Union
from typing import Any, Iterable
from pkg_resources import packaging # type: ignore[attr-defined]
@ -25,8 +25,7 @@ class TorchVersion(str):
TorchVersion('1.10.0a') > '1.2'
TorchVersion('1.10.0a') > '1.2.1'
"""
# fully qualified type names here to appease mypy
def _convert_to_version(self, inp: Union[packaging.version.Version, str, Iterable]) -> packaging.version.Version:
def _convert_to_version(self, inp: Any) -> Any:
if isinstance(inp, Version):
return inp
elif isinstance(inp, str):
@ -42,44 +41,16 @@ class TorchVersion(str):
else:
raise InvalidVersion(inp)
def __gt__(self, cmp):
def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
try:
return Version(self).__gt__(self._convert_to_version(cmp))
return getattr(Version(self), method)(self._convert_to_version(cmp))
except InvalidVersion:
# Fall back to regular string comparison if dealing with an invalid
# version like 'parrot'
return super().__gt__(cmp)
return getattr(super(), method)(cmp)
def __lt__(self, cmp):
try:
return Version(self).__lt__(self._convert_to_version(cmp))
except InvalidVersion:
# Fall back to regular string comparison if dealing with an invalid
# version like 'parrot'
return super().__lt__(cmp)
def __eq__(self, cmp):
try:
return Version(self).__eq__(self._convert_to_version(cmp))
except InvalidVersion:
# Fall back to regular string comparison if dealing with an invalid
# version like 'parrot'
return super().__eq__(cmp)
def __ge__(self, cmp):
try:
return Version(self).__ge__(self._convert_to_version(cmp))
except InvalidVersion:
# Fall back to regular string comparison if dealing with an invalid
# version like 'parrot'
return super().__ge__(cmp)
def __le__(self, cmp):
try:
return Version(self).__le__(self._convert_to_version(cmp))
except InvalidVersion:
# Fall back to regular string comparison if dealing with an invalid
# version like 'parrot'
return super().__le__(cmp)
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
setattr(TorchVersion, cmp_method, lambda x, y, method=cmp_method: x._cmp_wrapper(y, method))
__version__ = TorchVersion(internal_version)