mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Gracefully handle optree less than minimum version, part 2 (#151257)
If optree is less than the minimum version, we should pretend it doesn't exist. The problem right now is: - Install optree==0.12.1 - `import torch._dynamo` - This raise an error "min optree version is 0.13.0" The fix is to pretend optree doesn't exist if it is less than the min version. There are ways to clean up this PR more (e.g. have a single source of truth for the version, some of the variables are redundant), but I am trying to reduce the risk as much as possible for this to go into 2.7. Test Plan: I verified the above problem was fixed. Also tried some other things, like the following, which now gives the expected behavior. ```py >>> import torch >>> import optree >>> optree.__version__ '0.12.1' >>> import torch._dynamo >>> import torch._dynamo.polyfills.pytree >>> import torch.utils._pytree >>> import torch.utils._cxx_pytree ImportError: torch.utils._cxx_pytree depends on optree, which is an optional dependency of PyTorch. To u se it, please upgrade your optree package to >= 0.13.0 ``` I also audited all non-test callsites of optree and torch.utils._cxx_pytree. Follow along with me: optree imports - torch.utils._cxx_pytree. This is fine. - [guarded by check]f76b7ef33c/torch/_dynamo/polyfills/pytree.py (L29-L31)
_cxx_pytree imports - [guarded by check] torch.utils._pytree (changed in this PR) - [guarded by check] torch/_dynamo/polyfills/pytree.py (changed in this PR) - [guarded by try-catch]f76b7ef33c/torch/distributed/_functional_collectives.py (L17)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/_op_schema.py (L15)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/_dispatch.py (L35)
- [guarded by try-catch]f76b7ef33c/torch/_dynamo/variables/user_defined.py (L94)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/experimental/_func_map.py (L14)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151257 Approved by: https://github.com/malfet, https://github.com/XuehaiPan
This commit is contained in:
@ -24,6 +24,7 @@ import optree
|
||||
from torch._vendor.packaging.version import Version
|
||||
|
||||
|
||||
# Keep the version in sync with torch.utils._cxx_pytree!
|
||||
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
|
||||
raise ImportError(
|
||||
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "
|
||||
|
Reference in New Issue
Block a user