Gracefully handle optree less than minimum version, part 2 (#151257)

If optree is less than the minimum version, we should pretend it doesn't
exist.

The problem right now is:
- Install optree==0.12.1
- `import torch._dynamo`
- This raise an error "min optree version is 0.13.0"

The fix is to pretend optree doesn't exist if it is less than the min
version.

There are ways to clean up this PR more (e.g. have a single source of
truth for the version, some of the variables are redundant), but I am
trying to reduce the risk as much as possible for this to go into 2.7.

Test Plan:

I verified the above problem was fixed. Also tried some other things,
like the following, which now gives the expected behavior.
```py
>>> import torch
>>> import optree
>>> optree.__version__
'0.12.1'
>>> import torch._dynamo
>>> import torch._dynamo.polyfills.pytree
>>> import torch.utils._pytree
>>> import torch.utils._cxx_pytree
ImportError: torch.utils._cxx_pytree depends on optree, which is
an optional dependency of PyTorch. To u
se it, please upgrade your optree package to >= 0.13.0
```

I also audited all non-test callsites of optree and torch.utils._cxx_pytree.
Follow along with me:

optree imports
- torch.utils._cxx_pytree. This is fine.
- [guarded by check] f76b7ef33c/torch/_dynamo/polyfills/pytree.py (L29-L31)

_cxx_pytree imports
- [guarded by check] torch.utils._pytree (changed in this PR)
- [guarded by check] torch/_dynamo/polyfills/pytree.py (changed in this PR)
- [guarded by try-catch] f76b7ef33c/torch/distributed/_functional_collectives.py (L17)
- [guarded by try-catch] f76b7ef33c/torch/distributed/tensor/_op_schema.py (L15)
- [guarded by try-catch] f76b7ef33c/torch/distributed/tensor/_dispatch.py (L35)
- [guarded by try-catch] f76b7ef33c/torch/_dynamo/variables/user_defined.py (L94)
- [guarded by try-catch] f76b7ef33c/torch/distributed/tensor/experimental/_func_map.py (L14)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151257
Approved by: https://github.com/malfet, https://github.com/XuehaiPan
This commit is contained in:
rzou
2025-04-14 12:31:13 -07:00
committed by PyTorch MergeBot
parent 12cb11a268
commit f1f18c75c9
3 changed files with 16 additions and 6 deletions

View File

@ -20,8 +20,6 @@ if TYPE_CHECKING:
from collections.abc import Iterable
from typing_extensions import Self
from torch.utils._cxx_pytree import PyTree
__all__: list[str] = []
@ -32,6 +30,9 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import torch.utils._cxx_pytree as cxx_pytree
if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
@substitute_in_graph(
optree._C.is_dict_insertion_ordered,
can_constant_fold_through=True,

View File

@ -24,6 +24,7 @@ import optree
from torch._vendor.packaging.version import Version
# Keep the version in sync with torch.utils._cxx_pytree!
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
raise ImportError(
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "

View File

@ -173,12 +173,20 @@ SERIALIZED_TYPE_TO_PYTHON_TYPE: dict[str, type[Any]] = {}
try:
_optree_version = importlib.metadata.version("optree")
except importlib.metadata.PackageNotFoundError:
# optree was not imported
# No optree package found
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
else:
# optree was imported
_cxx_pytree_exists = True
_cxx_pytree_dynamo_traceable = True
from torch._vendor.packaging.version import Version
# Keep this in sync with torch.utils._cxx_pytree!
if Version(_optree_version) < Version("0.13.0"):
# optree package less than our required minimum version.
# Pretend the optree package doesn't exist.
# NB: We will raise ImportError if the user directly tries to
# `import torch.utils._cxx_pytree` (look in that file for the check).
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
else:
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = True
_cxx_pytree_imported = False
_cxx_pytree_pending_imports: list[Any] = []