Add __ge__ to TorchVersion (#64565)

Summary:
This PR adds greater equal comparison so that not the base class's (str) comparison method is used.
This is necessary for a correct comparison with a version string.

Previously the following was the case:
```py
>>> torch.__version__
'1.10.0.dev20210830+cpu'
>>> torch.__version__>"1.9"
True
>>> torch.__version__>="1.9"
False  # Wrong output since the base class (str) was used for __ge__ comparison
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64565

Reviewed By: raghuramank100

Differential Revision: D30790463

Pulled By: mrshenli

fbshipit-source-id: 79c680f8b448001b34d3e5d5332124a78bea4e34
This commit is contained in:
Hendrik Schröter
2021-09-07 20:14:08 -07:00
committed by Facebook GitHub Bot
parent 81fe2c5e49
commit 5721205417

View File

@ -65,4 +65,13 @@ class TorchVersion(str):
# version like 'parrot'
return super().__eq__(cmp)
def __ge__(self, cmp):
try:
return Version(self).__ge__(self._convert_to_version(cmp))
except InvalidVersion:
# Fall back to regular string comparison if dealing with an invalid
# version like 'parrot'
return super().__ge__(cmp)
__version__ = TorchVersion(internal_version)