mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Lazy import packaging
in torch_version
(#71345)
Summary: As it is a pretty big package and to be used during normal course of PyTorch initialization Fixes https://github.com/pytorch/pytorch/issues/71280 Pull Request resolved: https://github.com/pytorch/pytorch/pull/71345 Reviewed By: seemethere Differential Revision: D33594547 Pulled By: malfet fbshipit-source-id: e0abea82dbdc29914512b610692701140d3e68a2 (cherry picked from commit 1ff7f65cc1ad499a71457368894ca14bed069749)
This commit is contained in:
committed by
PyTorch MergeBot
parent
efd274bbcb
commit
a986154950
@ -1,12 +1,35 @@
|
||||
from typing import Any, Iterable
|
||||
|
||||
from pkg_resources import packaging # type: ignore[attr-defined]
|
||||
|
||||
Version = packaging.version.Version
|
||||
InvalidVersion = packaging.version.InvalidVersion
|
||||
|
||||
from .version import __version__ as internal_version
|
||||
|
||||
class _LazyImport:
|
||||
"""Wraps around classes lazy imported from packaging.version
|
||||
Output of the function v in following snippets are identical:
|
||||
from packaging.version import Version
|
||||
def v():
|
||||
return Version('1.2.3')
|
||||
and
|
||||
Versoin = _LazyImport('Version')
|
||||
def v():
|
||||
return Version('1.2.3')
|
||||
The difference here is that in later example imports
|
||||
do not happen until v is called
|
||||
"""
|
||||
def __init__(self, cls_name: str) -> None:
|
||||
self._cls_name = cls_name
|
||||
|
||||
def get_cls(self):
|
||||
from pkg_resources import packaging # type: ignore[attr-defined]
|
||||
return getattr(packaging.version, self._cls_name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.get_cls()(*args, **kwargs)
|
||||
|
||||
def __instancecheck__(self, obj):
|
||||
return isinstance(obj, self.get_cls())
|
||||
|
||||
|
||||
Version = _LazyImport("Version")
|
||||
InvalidVersion = _LazyImport("InvalidVersion")
|
||||
|
||||
class TorchVersion(str):
|
||||
"""A string with magic powers to compare to both Version and iterables!
|
||||
@ -25,8 +48,9 @@ class TorchVersion(str):
|
||||
TorchVersion('1.10.0a') > '1.2'
|
||||
TorchVersion('1.10.0a') > '1.2.1'
|
||||
"""
|
||||
# fully qualified type names here to appease mypy
|
||||
def _convert_to_version(self, inp: Any) -> Any:
|
||||
if isinstance(inp, Version):
|
||||
if isinstance(inp, Version.get_cls()):
|
||||
return inp
|
||||
elif isinstance(inp, str):
|
||||
return Version(inp)
|
||||
@ -44,7 +68,9 @@ class TorchVersion(str):
|
||||
def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
|
||||
try:
|
||||
return getattr(Version(self), method)(self._convert_to_version(cmp))
|
||||
except InvalidVersion:
|
||||
except BaseException as e:
|
||||
if not isinstance(e, InvalidVersion.get_cls()):
|
||||
raise
|
||||
# Fall back to regular string comparison if dealing with an invalid
|
||||
# version like 'parrot'
|
||||
return getattr(super(), method)(cmp)
|
||||
|
Reference in New Issue
Block a user