diff --git a/torch/torch_version.py b/torch/torch_version.py index bb01294aebf9..4998c557d409 100644 --- a/torch/torch_version.py +++ b/torch/torch_version.py @@ -65,4 +65,13 @@ class TorchVersion(str): # version like 'parrot' return super().__eq__(cmp) + + def __ge__(self, cmp): + try: + return Version(self).__ge__(self._convert_to_version(cmp)) + except InvalidVersion: + # Fall back to regular string comparison if dealing with an invalid + # version like 'parrot' + return super().__ge__(cmp) + __version__ = TorchVersion(internal_version)