Files
pytorch/torch/torch_version.py
Aaron Gokaslan 7f65a20884 [BE]: Enable ruff SLOT checks (#146276)
This enables a check that a class which only inherits from immutable classes like str, tuple, and NamedTuple also defines `__slots__`, so that its instances don't allocate memory unnecessarily. This also ensures contributors think about how they define classes that subclass NamedTuple and str, of which we have many in our codebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146276
Approved by: https://github.com/aorenste
2025-02-04 19:18:23 +00:00

67 lines
2.5 KiB
Python

from collections.abc import Iterable
from typing import Any
from torch._vendor.packaging.version import InvalidVersion, Version
from torch.version import __version__ as internal_version
__all__ = ["TorchVersion"]
class TorchVersion(str):
"""A string with magic powers to compare to both Version and iterables!
Prior to 1.10.0 torch.__version__ was stored as a str and so many did
comparisons against torch.__version__ as if it were a str. In order to not
break them we have TorchVersion which masquerades as a str while also
having the ability to compare against both packaging.version.Version as
well as tuples of values, eg. (1, 2, 1)
Examples:
Comparing a TorchVersion object to a Version object
TorchVersion('1.10.0a') > Version('1.10.0a')
Comparing a TorchVersion object to a Tuple object
TorchVersion('1.10.0a') > (1, 2) # 1.2
TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
Comparing a TorchVersion object against a string
TorchVersion('1.10.0a') > '1.2'
TorchVersion('1.10.0a') > '1.2.1'
"""
__slots__ = ()
# fully qualified type names here to appease mypy
def _convert_to_version(self, inp: Any) -> Any:
if isinstance(inp, Version):
return inp
elif isinstance(inp, str):
return Version(inp)
elif isinstance(inp, Iterable):
# Ideally this should work for most cases by attempting to group
# the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
# Examples:
# * (1) -> Version("1")
# * (1, 20) -> Version("1.20")
# * (1, 20, 1) -> Version("1.20.1")
return Version(".".join(str(item) for item in inp))
else:
raise InvalidVersion(inp)
def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
try:
return getattr(Version(self), method)(self._convert_to_version(cmp))
except BaseException as e:
if not isinstance(e, InvalidVersion):
raise
# Fall back to regular string comparison if dealing with an invalid
# version like 'parrot'
return getattr(super(), method)(cmp)
# Install Version-aware rich comparisons on TorchVersion. Attaching them after
# class creation (rather than defining __eq__ in the class body, which would
# set __hash__ = None) leaves str.__hash__ inherited, so instances stay hashable.
#
# "__ne__" must be listed explicitly: str defines its own __ne__, which would
# otherwise be used for `!=` and make e.g. TorchVersion("1.10.0") != "1.10"
# (and != (1, 10, 0)) evaluate True even though `==` also returns True.
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ne__", "__ge__", "__le__"]:
    setattr(
        TorchVersion,
        cmp_method,
        # Bind the current name as a default argument; a plain closure over
        # cmp_method would only ever see the loop's final value.
        lambda x, y, method=cmp_method: x._cmp_wrapper(y, method),
    )

__version__ = TorchVersion(internal_version)