mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Gracefully handle optree less than minimum version, part 2 (#151257)
If optree is less than the minimum version, we should pretend it doesn't exist. The problem right now is: - Install optree==0.12.1 - `import torch._dynamo` - This raise an error "min optree version is 0.13.0" The fix is to pretend optree doesn't exist if it is less than the min version. There are ways to clean up this PR more (e.g. have a single source of truth for the version, some of the variables are redundant), but I am trying to reduce the risk as much as possible for this to go into 2.7. Test Plan: I verified the above problem was fixed. Also tried some other things, like the following, which now gives the expected behavior. ```py >>> import torch >>> import optree >>> optree.__version__ '0.12.1' >>> import torch._dynamo >>> import torch._dynamo.polyfills.pytree >>> import torch.utils._pytree >>> import torch.utils._cxx_pytree ImportError: torch.utils._cxx_pytree depends on optree, which is an optional dependency of PyTorch. To u se it, please upgrade your optree package to >= 0.13.0 ``` I also audited all non-test callsites of optree and torch.utils._cxx_pytree. Follow along with me: optree imports - torch.utils._cxx_pytree. This is fine. - [guarded by check]f76b7ef33c/torch/_dynamo/polyfills/pytree.py (L29-L31)
_cxx_pytree imports - [guarded by check] torch.utils._pytree (changed in this PR) - [guarded by check] torch/_dynamo/polyfills/pytree.py (changed in this PR) - [guarded by try-catch]f76b7ef33c/torch/distributed/_functional_collectives.py (L17)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/_op_schema.py (L15)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/_dispatch.py (L35)
- [guarded by try-catch]f76b7ef33c/torch/_dynamo/variables/user_defined.py (L94)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/experimental/_func_map.py (L14)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151257 Approved by: https://github.com/malfet, https://github.com/XuehaiPan
This commit is contained in:
@ -20,8 +20,6 @@ if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
from typing_extensions import Self
|
||||
|
||||
from torch.utils._cxx_pytree import PyTree
|
||||
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
@ -32,6 +30,9 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
|
||||
import torch.utils._cxx_pytree as cxx_pytree
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._cxx_pytree import PyTree
|
||||
|
||||
@substitute_in_graph(
|
||||
optree._C.is_dict_insertion_ordered,
|
||||
can_constant_fold_through=True,
|
||||
|
@ -24,6 +24,7 @@ import optree
|
||||
from torch._vendor.packaging.version import Version
|
||||
|
||||
|
||||
# Keep the version in sync with torch.utils._cxx_pytree!
|
||||
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
|
||||
raise ImportError(
|
||||
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "
|
||||
|
@ -173,12 +173,20 @@ SERIALIZED_TYPE_TO_PYTHON_TYPE: dict[str, type[Any]] = {}
|
||||
try:
|
||||
_optree_version = importlib.metadata.version("optree")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
# optree was not imported
|
||||
# No optree package found
|
||||
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
|
||||
else:
|
||||
# optree was imported
|
||||
_cxx_pytree_exists = True
|
||||
_cxx_pytree_dynamo_traceable = True
|
||||
from torch._vendor.packaging.version import Version
|
||||
|
||||
# Keep this in sync with torch.utils._cxx_pytree!
|
||||
if Version(_optree_version) < Version("0.13.0"):
|
||||
# optree package less than our required minimum version.
|
||||
# Pretend the optree package doesn't exist.
|
||||
# NB: We will raise ImportError if the user directly tries to
|
||||
# `import torch.utils._cxx_pytree` (look in that file for the check).
|
||||
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
|
||||
else:
|
||||
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = True
|
||||
|
||||
_cxx_pytree_imported = False
|
||||
_cxx_pytree_pending_imports: list[Any] = []
|
||||
|
Reference in New Issue
Block a user