mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add __ge__ to TorchVersion (#64565)
Summary: This PR adds greater equal comparison so that not the base class's (str) comparison method is used. This is necessary for a correct comparison with a version string. Previously the following was the case: ```py >>> torch.__version__ '1.10.0.dev20210830+cpu' >>> torch.__version__>"1.9" True >>> torch.__version__>="1.9" False # Wrong output since the base class (str) was used for __ge__ comparison ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/64565 Reviewed By: raghuramank100 Differential Revision: D30790463 Pulled By: mrshenli fbshipit-source-id: 79c680f8b448001b34d3e5d5332124a78bea4e34
This commit is contained in:
committed by
Facebook GitHub Bot
parent
81fe2c5e49
commit
5721205417
@ -65,4 +65,13 @@ class TorchVersion(str):
|
||||
# version like 'parrot'
|
||||
return super().__eq__(cmp)
|
||||
|
||||
|
||||
def __ge__(self, cmp):
|
||||
try:
|
||||
return Version(self).__ge__(self._convert_to_version(cmp))
|
||||
except InvalidVersion:
|
||||
# Fall back to regular string comparison if dealing with an invalid
|
||||
# version like 'parrot'
|
||||
return super().__ge__(cmp)
|
||||
|
||||
__version__ = TorchVersion(internal_version)
|
||||
|
Reference in New Issue
Block a user