Lazy import packaging in torch_version (#71345)

Summary:
As it is a pretty big package and to be used during normal
course of PyTorch initialization

Fixes https://github.com/pytorch/pytorch/issues/71280

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71345

Reviewed By: seemethere

Differential Revision: D33594547

Pulled By: malfet

fbshipit-source-id: e0abea82dbdc29914512b610692701140d3e68a2
(cherry picked from commit 1ff7f65cc1ad499a71457368894ca14bed069749)
This commit is contained in:
Nikita Shulga
2022-01-18 12:37:48 -08:00
committed by PyTorch MergeBot
parent efd274bbcb
commit a986154950

View File

@ -1,12 +1,35 @@
from typing import Any, Iterable
from pkg_resources import packaging # type: ignore[attr-defined]
Version = packaging.version.Version
InvalidVersion = packaging.version.InvalidVersion
from .version import __version__ as internal_version
class _LazyImport:
"""Wraps around classes lazy imported from packaging.version
Output of the function v in following snippets are identical:
from packaging.version import Version
def v():
return Version('1.2.3')
and
Versoin = _LazyImport('Version')
def v():
return Version('1.2.3')
The difference here is that in later example imports
do not happen until v is called
"""
def __init__(self, cls_name: str) -> None:
self._cls_name = cls_name
def get_cls(self):
from pkg_resources import packaging # type: ignore[attr-defined]
return getattr(packaging.version, self._cls_name)
def __call__(self, *args, **kwargs):
return self.get_cls()(*args, **kwargs)
def __instancecheck__(self, obj):
return isinstance(obj, self.get_cls())
Version = _LazyImport("Version")
InvalidVersion = _LazyImport("InvalidVersion")
class TorchVersion(str):
"""A string with magic powers to compare to both Version and iterables!
@ -25,8 +48,9 @@ class TorchVersion(str):
TorchVersion('1.10.0a') > '1.2'
TorchVersion('1.10.0a') > '1.2.1'
"""
# fully qualified type names here to appease mypy
def _convert_to_version(self, inp: Any) -> Any:
if isinstance(inp, Version):
if isinstance(inp, Version.get_cls()):
return inp
elif isinstance(inp, str):
return Version(inp)
@ -44,7 +68,9 @@ class TorchVersion(str):
def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
try:
return getattr(Version(self), method)(self._convert_to_version(cmp))
except InvalidVersion:
except BaseException as e:
if not isinstance(e, InvalidVersion.get_cls()):
raise
# Fall back to regular string comparison if dealing with an invalid
# version like 'parrot'
return getattr(super(), method)(cmp)