mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Refactor repetitions into TorchVersion._cmp_wrapper` (#71344)
Summary: First step towards https://github.com/pytorch/pytorch/issues/71280 Pull Request resolved: https://github.com/pytorch/pytorch/pull/71344 Reviewed By: b0noI Differential Revision: D33594463 Pulled By: malfet fbshipit-source-id: 0295f0d9f0342f05a390b2bd4aa0a5958c76579b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c43e0286a9
commit
3ed27a96ed
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, Union
|
||||
from typing import Any, Iterable
|
||||
|
||||
from pkg_resources import packaging # type: ignore[attr-defined]
|
||||
|
||||
@ -25,8 +25,7 @@ class TorchVersion(str):
|
||||
TorchVersion('1.10.0a') > '1.2'
|
||||
TorchVersion('1.10.0a') > '1.2.1'
|
||||
"""
|
||||
# fully qualified type names here to appease mypy
|
||||
def _convert_to_version(self, inp: Union[packaging.version.Version, str, Iterable]) -> packaging.version.Version:
|
||||
def _convert_to_version(self, inp: Any) -> Any:
|
||||
if isinstance(inp, Version):
|
||||
return inp
|
||||
elif isinstance(inp, str):
|
||||
@ -42,44 +41,16 @@ class TorchVersion(str):
|
||||
else:
|
||||
raise InvalidVersion(inp)
|
||||
|
||||
def __gt__(self, cmp):
|
||||
def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
|
||||
try:
|
||||
return Version(self).__gt__(self._convert_to_version(cmp))
|
||||
return getattr(Version(self), method)(self._convert_to_version(cmp))
|
||||
except InvalidVersion:
|
||||
# Fall back to regular string comparison if dealing with an invalid
|
||||
# version like 'parrot'
|
||||
return super().__gt__(cmp)
|
||||
return getattr(super(), method)(cmp)
|
||||
|
||||
def __lt__(self, cmp):
|
||||
try:
|
||||
return Version(self).__lt__(self._convert_to_version(cmp))
|
||||
except InvalidVersion:
|
||||
# Fall back to regular string comparison if dealing with an invalid
|
||||
# version like 'parrot'
|
||||
return super().__lt__(cmp)
|
||||
|
||||
def __eq__(self, cmp):
|
||||
try:
|
||||
return Version(self).__eq__(self._convert_to_version(cmp))
|
||||
except InvalidVersion:
|
||||
# Fall back to regular string comparison if dealing with an invalid
|
||||
# version like 'parrot'
|
||||
return super().__eq__(cmp)
|
||||
|
||||
def __ge__(self, cmp):
|
||||
try:
|
||||
return Version(self).__ge__(self._convert_to_version(cmp))
|
||||
except InvalidVersion:
|
||||
# Fall back to regular string comparison if dealing with an invalid
|
||||
# version like 'parrot'
|
||||
return super().__ge__(cmp)
|
||||
|
||||
def __le__(self, cmp):
|
||||
try:
|
||||
return Version(self).__le__(self._convert_to_version(cmp))
|
||||
except InvalidVersion:
|
||||
# Fall back to regular string comparison if dealing with an invalid
|
||||
# version like 'parrot'
|
||||
return super().__le__(cmp)
|
||||
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
|
||||
setattr(TorchVersion, cmp_method, lambda x, y, method=cmp_method: x._cmp_wrapper(y, method))
|
||||
|
||||
__version__ = TorchVersion(internal_version)
|
||||
|
Reference in New Issue
Block a user