mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[Dynamo] only import einops if version is lower than 0.7.0 (#142847)"
This reverts commit a96387a481633389a6b5a5ac7b8406e9216f320e. Reverted https://github.com/pytorch/pytorch/pull/142847 on behalf of https://github.com/huydhn due to This has been reverted internally D67436053 ([comment](https://github.com/pytorch/pytorch/pull/142847#issuecomment-2555942351))
This commit is contained in:
@ -2,12 +2,10 @@
|
||||
# ruff: noqa: TCH004
|
||||
import functools
|
||||
import inspect
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar
|
||||
|
||||
import torch
|
||||
from torch._vendor.packaging.version import Version
|
||||
from torch.utils._contextlib import _DecoratorContextManager
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
@ -610,33 +608,27 @@ def mark_static_address(t, guard=True):
|
||||
# Note: this carefully avoids eagerly import einops.
|
||||
# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2
|
||||
def _allow_in_graph_einops():
|
||||
mod = sys.modules.get("einops")
|
||||
if mod is None:
|
||||
return
|
||||
else:
|
||||
# version > 0.7.0 does allow_in_graph out of tree
|
||||
if Version(mod.__version__) < Version("0.7.0"):
|
||||
import einops
|
||||
import einops
|
||||
|
||||
try:
|
||||
# requires einops > 0.6.1, torch >= 2.0
|
||||
from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
|
||||
_ops_were_registered_in_torchdynamo,
|
||||
)
|
||||
try:
|
||||
# requires einops > 0.6.1, torch >= 2.0
|
||||
from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
|
||||
_ops_were_registered_in_torchdynamo,
|
||||
)
|
||||
|
||||
# einops > 0.6.1 will call the op registration logic as it is imported.
|
||||
except ImportError:
|
||||
# einops <= 0.6.1
|
||||
allow_in_graph(einops.rearrange)
|
||||
allow_in_graph(einops.reduce)
|
||||
if hasattr(einops, "repeat"):
|
||||
allow_in_graph(einops.repeat) # available since einops 0.2.0
|
||||
if hasattr(einops, "einsum"):
|
||||
allow_in_graph(einops.einsum) # available since einops 0.5.0
|
||||
if hasattr(einops, "pack"):
|
||||
allow_in_graph(einops.pack) # available since einops 0.6.0
|
||||
if hasattr(einops, "unpack"):
|
||||
allow_in_graph(einops.unpack) # available since einops 0.6.0
|
||||
# einops > 0.6.1 will call the op registration logic as it is imported.
|
||||
except ImportError:
|
||||
# einops <= 0.6.1
|
||||
allow_in_graph(einops.rearrange)
|
||||
allow_in_graph(einops.reduce)
|
||||
if hasattr(einops, "repeat"):
|
||||
allow_in_graph(einops.repeat) # available since einops 0.2.0
|
||||
if hasattr(einops, "einsum"):
|
||||
allow_in_graph(einops.einsum) # available since einops 0.5.0
|
||||
if hasattr(einops, "pack"):
|
||||
allow_in_graph(einops.pack) # available since einops 0.6.0
|
||||
if hasattr(einops, "unpack"):
|
||||
allow_in_graph(einops.unpack) # available since einops 0.6.0
|
||||
|
||||
|
||||
trace_rules.add_module_init_func("einops", _allow_in_graph_einops)
|
||||
|
Reference in New Issue
Block a user