mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE] detect CXX pytree requirement with TorchVersion
(#151102)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151102 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
8cb6957e01
commit
f1d636f85b
@ -19,24 +19,8 @@ from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, overload, TypeVar, Union
|
||||
from typing_extensions import deprecated, TypeIs
|
||||
|
||||
import optree
|
||||
|
||||
from torch._vendor.packaging.version import Version
|
||||
|
||||
|
||||
# Keep the version in sync with torch.utils._cxx_pytree!
|
||||
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
|
||||
raise ImportError(
|
||||
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "
|
||||
"of PyTorch. To use it, please upgrade your optree package to >= 0.13.0"
|
||||
)
|
||||
|
||||
del Version
|
||||
|
||||
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.torch_version import TorchVersion as _TorchVersion
|
||||
from torch.utils._pytree import (
|
||||
is_namedtuple as is_namedtuple,
|
||||
is_namedtuple_class as is_namedtuple_class,
|
||||
@ -48,6 +32,20 @@ from torch.utils._pytree import (
|
||||
)
|
||||
|
||||
|
||||
# Do not try to import `optree` package if the static version check already fails.
|
||||
if not python_pytree._cxx_pytree_dynamo_traceable:
|
||||
raise ImportError(
|
||||
f"{__name__} depends on `optree>={python_pytree._optree_minimum_version}`, "
|
||||
"which is an optional dependency of PyTorch. "
|
||||
"To use it, please upgrade your optree package via "
|
||||
"`python3 -m pip install --upgrade optree`"
|
||||
)
|
||||
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTree",
|
||||
"Context",
|
||||
@ -90,6 +88,9 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
# In-tree installation may have VCS-based versioning. Update the previous static version.
|
||||
python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined]
|
||||
|
||||
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
|
||||
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently
|
||||
|
||||
|
@ -43,6 +43,8 @@ from typing import (
|
||||
)
|
||||
from typing_extensions import deprecated, NamedTuple, Self
|
||||
|
||||
from torch.torch_version import TorchVersion as _TorchVersion
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTree",
|
||||
@ -170,16 +172,16 @@ SERIALIZED_TYPE_TO_PYTHON_TYPE: dict[str, type[Any]] = {}
|
||||
# NB: we try really hard to not import _cxx_pytree (which depends on optree)
|
||||
# as much as possible. This is for isolation: a user who is not using C++ pytree
|
||||
# shouldn't pay for it, and it helps makes things like cpython upgrades easier.
|
||||
_optree_minimum_version = _TorchVersion("0.13.0")
|
||||
try:
|
||||
_optree_version = importlib.metadata.version("optree")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
# No optree package found
|
||||
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
|
||||
_optree_version = _TorchVersion("0.0.0a0")
|
||||
else:
|
||||
from torch._vendor.packaging.version import Version
|
||||
|
||||
# Keep this in sync with torch.utils._cxx_pytree!
|
||||
if Version(_optree_version) < Version("0.13.0"):
|
||||
_optree_version = _TorchVersion(_optree_version)
|
||||
if _optree_version < _optree_minimum_version:
|
||||
# optree package less than our required minimum version.
|
||||
# Pretend the optree package doesn't exist.
|
||||
# NB: We will raise ImportError if the user directly tries to
|
||||
|
Reference in New Issue
Block a user