[BE] detect CXX pytree requirement with TorchVersion (#151102)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151102
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2025-04-26 11:34:37 +08:00
committed by PyTorch MergeBot
parent 8cb6957e01
commit f1d636f85b
2 changed files with 24 additions and 21 deletions

View File

@ -19,24 +19,8 @@ from collections.abc import Iterable
from typing import Any, Callable, Optional, overload, TypeVar, Union
from typing_extensions import deprecated, TypeIs
import optree
from torch._vendor.packaging.version import Version
# Keep the version in sync with torch.utils._cxx_pytree!
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
raise ImportError(
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "
"of PyTorch. To use it, please upgrade your optree package to >= 0.13.0"
)
del Version
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
import torch.utils._pytree as python_pytree
from torch.torch_version import TorchVersion as _TorchVersion
from torch.utils._pytree import (
is_namedtuple as is_namedtuple,
is_namedtuple_class as is_namedtuple_class,
@ -48,6 +32,20 @@ from torch.utils._pytree import (
)
# Do not try to import `optree` package if the static version check already fails.
if not python_pytree._cxx_pytree_dynamo_traceable:
raise ImportError(
f"{__name__} depends on `optree>={python_pytree._optree_minimum_version}`, "
"which is an optional dependency of PyTorch. "
"To use it, please upgrade your optree package via "
"`python3 -m pip install --upgrade optree`"
)
import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
__all__ = [
"PyTree",
"Context",
@ -90,6 +88,9 @@ __all__ = [
]
# In-tree installation may have VCS-based versioning. Update the previous static version.
python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined]
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently