[BE] detect CXX pytree requirement with TorchVersion (#151102)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151102
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2025-04-26 11:34:37 +08:00
committed by PyTorch MergeBot
parent 8cb6957e01
commit f1d636f85b
2 changed files with 24 additions and 21 deletions

View File

@ -19,24 +19,8 @@ from collections.abc import Iterable
from typing import Any, Callable, Optional, overload, TypeVar, Union
from typing_extensions import deprecated, TypeIs
import optree
from torch._vendor.packaging.version import Version
# Keep the version in sync with torch.utils._cxx_pytree!
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
raise ImportError(
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "
"of PyTorch. To use it, please upgrade your optree package to >= 0.13.0"
)
del Version
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
import torch.utils._pytree as python_pytree
from torch.torch_version import TorchVersion as _TorchVersion
from torch.utils._pytree import (
is_namedtuple as is_namedtuple,
is_namedtuple_class as is_namedtuple_class,
@ -48,6 +32,20 @@ from torch.utils._pytree import (
)
# Do not try to import `optree` package if the static version check already fails.
if not python_pytree._cxx_pytree_dynamo_traceable:
raise ImportError(
f"{__name__} depends on `optree>={python_pytree._optree_minimum_version}`, "
"which is an optional dependency of PyTorch. "
"To use it, please upgrade your optree package via "
"`python3 -m pip install --upgrade optree`"
)
import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
__all__ = [
"PyTree",
"Context",
@ -90,6 +88,9 @@ __all__ = [
]
# In-tree installation may have VCS-based versioning. Update the previous static version.
python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined]
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently

View File

@ -43,6 +43,8 @@ from typing import (
)
from typing_extensions import deprecated, NamedTuple, Self
from torch.torch_version import TorchVersion as _TorchVersion
__all__ = [
"PyTree",
@ -170,16 +172,16 @@ SERIALIZED_TYPE_TO_PYTHON_TYPE: dict[str, type[Any]] = {}
# NB: we try really hard to not import _cxx_pytree (which depends on optree)
# as much as possible. This is for isolation: a user who is not using C++ pytree
# shouldn't pay for it, and it helps makes things like cpython upgrades easier.
_optree_minimum_version = _TorchVersion("0.13.0")
try:
_optree_version = importlib.metadata.version("optree")
except importlib.metadata.PackageNotFoundError:
# No optree package found
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
_optree_version = _TorchVersion("0.0.0a0")
else:
from torch._vendor.packaging.version import Version
# Keep this in sync with torch.utils._cxx_pytree!
if Version(_optree_version) < Version("0.13.0"):
_optree_version = _TorchVersion(_optree_version)
if _optree_version < _optree_minimum_version:
# optree package less than our required minimum version.
# Pretend the optree package doesn't exist.
# NB: We will raise ImportError if the user directly tries to