Revert "[Dynamo] only import einops if version is lower than 0.7.0 (#142847)"

This reverts commit a96387a481633389a6b5a5ac7b8406e9216f320e.

Reverted https://github.com/pytorch/pytorch/pull/142847 on behalf of https://github.com/huydhn due to This has been reverted internally D67436053 ([comment](https://github.com/pytorch/pytorch/pull/142847#issuecomment-2555942351))
This commit is contained in:
PyTorch MergeBot
2024-12-19 23:22:43 +00:00
parent d2b83aa122
commit 145fd5bad0

View File

@ -2,12 +2,10 @@
# ruff: noqa: TCH004
import functools
import inspect
import sys
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar
import torch
from torch._vendor.packaging.version import Version
from torch.utils._contextlib import _DecoratorContextManager
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@ -610,33 +608,27 @@ def mark_static_address(t, guard=True):
# Note: this carefully avoids eagerly import einops.
# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2
def _allow_in_graph_einops():
mod = sys.modules.get("einops")
if mod is None:
return
else:
# version > 0.7.0 does allow_in_graph out of tree
if Version(mod.__version__) < Version("0.7.0"):
import einops
import einops
try:
# requires einops > 0.6.1, torch >= 2.0
from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
_ops_were_registered_in_torchdynamo,
)
try:
# requires einops > 0.6.1, torch >= 2.0
from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
_ops_were_registered_in_torchdynamo,
)
# einops > 0.6.1 will call the op registration logic as it is imported.
except ImportError:
# einops <= 0.6.1
allow_in_graph(einops.rearrange)
allow_in_graph(einops.reduce)
if hasattr(einops, "repeat"):
allow_in_graph(einops.repeat) # available since einops 0.2.0
if hasattr(einops, "einsum"):
allow_in_graph(einops.einsum) # available since einops 0.5.0
if hasattr(einops, "pack"):
allow_in_graph(einops.pack) # available since einops 0.6.0
if hasattr(einops, "unpack"):
allow_in_graph(einops.unpack) # available since einops 0.6.0
# einops > 0.6.1 will call the op registration logic as it is imported.
except ImportError:
# einops <= 0.6.1
allow_in_graph(einops.rearrange)
allow_in_graph(einops.reduce)
if hasattr(einops, "repeat"):
allow_in_graph(einops.repeat) # available since einops 0.2.0
if hasattr(einops, "einsum"):
allow_in_graph(einops.einsum) # available since einops 0.5.0
if hasattr(einops, "pack"):
allow_in_graph(einops.pack) # available since einops 0.6.0
if hasattr(einops, "unpack"):
allow_in_graph(einops.unpack) # available since einops 0.6.0
trace_rules.add_module_init_func("einops", _allow_in_graph_einops)