mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] detect CXX pytree requirement with TorchVersion (#151102)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151102 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
8cb6957e01
commit
f1d636f85b
@ -19,24 +19,8 @@ from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, overload, TypeVar, Union
|
||||
from typing_extensions import deprecated, TypeIs
|
||||
|
||||
import optree
|
||||
|
||||
from torch._vendor.packaging.version import Version
|
||||
|
||||
|
||||
# Keep the version in sync with torch.utils._cxx_pytree!
|
||||
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
|
||||
raise ImportError(
|
||||
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "
|
||||
"of PyTorch. To use it, please upgrade your optree package to >= 0.13.0"
|
||||
)
|
||||
|
||||
del Version
|
||||
|
||||
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.torch_version import TorchVersion as _TorchVersion
|
||||
from torch.utils._pytree import (
|
||||
is_namedtuple as is_namedtuple,
|
||||
is_namedtuple_class as is_namedtuple_class,
|
||||
@ -48,6 +32,20 @@ from torch.utils._pytree import (
|
||||
)
|
||||
|
||||
|
||||
# Do not try to import `optree` package if the static version check already fails.
|
||||
if not python_pytree._cxx_pytree_dynamo_traceable:
|
||||
raise ImportError(
|
||||
f"{__name__} depends on `optree>={python_pytree._optree_minimum_version}`, "
|
||||
"which is an optional dependency of PyTorch. "
|
||||
"To use it, please upgrade your optree package via "
|
||||
"`python3 -m pip install --upgrade optree`"
|
||||
)
|
||||
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTree",
|
||||
"Context",
|
||||
@ -90,6 +88,9 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
# In-tree installation may have VCS-based versioning. Update the previous static version.
|
||||
python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined]
|
||||
|
||||
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
|
||||
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently
|
||||
|
||||
|
||||
Reference in New Issue
Block a user