mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit eeeb40b32717bab75bd7d8f28f8343385688b3ab. Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to Reverting this pr as the one under it in the stack is causing regressions in torchrec ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1806044435))
This commit is contained in:
@ -22,7 +22,7 @@ KEEP_ELLIPSES = 2
|
||||
KEEP_NAME_AND_ELLIPSES = 3
|
||||
|
||||
PRUNE_FUNCTIONS = {
|
||||
"torch/utils/_pytree/api/python.py(...): tree_map": KEEP_NAME_AND_ELLIPSES,
|
||||
"torch/utils/_pytree.py(...): tree_map": KEEP_NAME_AND_ELLIPSES,
|
||||
"torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
|
||||
"torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
|
||||
"torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
|
||||
@ -699,14 +699,14 @@ class TestProfilerTree(TestCase):
|
||||
...
|
||||
aten::add
|
||||
test_profiler_tree.py(...): __torch_dispatch__
|
||||
torch/utils/_pytree/api/python.py(...): tree_map
|
||||
torch/utils/_pytree.py(...): tree_map
|
||||
...
|
||||
torch/utils/_pytree/api/python.py(...): tree_map
|
||||
torch/utils/_pytree.py(...): tree_map
|
||||
...
|
||||
torch/_ops.py(...): __call__
|
||||
<built-in method of PyCapsule object at 0xXXXXXXXXXXXX>
|
||||
aten::add
|
||||
torch/utils/_pytree/api/python.py(...): tree_map
|
||||
torch/utils/_pytree.py(...): tree_map
|
||||
...
|
||||
torch/profiler/profiler.py(...): __exit__
|
||||
torch/profiler/profiler.py(...): stop
|
||||
|
@ -4,8 +4,8 @@ import unittest
|
||||
from collections import namedtuple, OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree.api.cxx as cxx_pytree
|
||||
import torch.utils._pytree.api.python as py_pytree
|
||||
import torch.utils._cxx_pytree as cxx_pytree
|
||||
import torch.utils._pytree as py_pytree
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
|
@ -12,7 +12,7 @@ from torch.fx.experimental.proxy_tensor import (
|
||||
from torch._custom_ops import impl_abstract
|
||||
|
||||
try:
|
||||
from torch.utils._pytree.api.cxx import tree_map_only
|
||||
from torch.utils._cxx_pytree import tree_map_only
|
||||
except ImportError:
|
||||
from torch.utils._pytree import tree_map_only # type: ignore[no-redef]
|
||||
|
||||
|
@ -23,9 +23,9 @@ from torch.distributed._tensor.redistribute import redistribute_local_tensor
|
||||
from torch.distributed._tensor.sharding_prop import ShardingPropagator
|
||||
|
||||
try:
|
||||
from torch.utils._pytree.api import cxx as pytree
|
||||
from torch.utils import _cxx_pytree as pytree
|
||||
except ImportError:
|
||||
from torch.utils._pytree.api import python as pytree # type: ignore[no-redef]
|
||||
from torch.utils import _pytree as pytree # type: ignore[no-redef]
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
@ -7,7 +7,7 @@ from torch.distributed._tensor.device_mesh import DeviceMesh
|
||||
from torch.distributed._tensor.placement_types import DTensorSpec
|
||||
|
||||
try:
|
||||
from torch.utils._pytree.api.cxx import tree_map_only, TreeSpec
|
||||
from torch.utils._cxx_pytree import tree_map_only, TreeSpec
|
||||
except ImportError:
|
||||
from torch.utils._pytree import ( # type: ignore[no-redef, assignment]
|
||||
tree_map_only,
|
||||
|
@ -13,7 +13,18 @@ collection support for PyTorch APIs.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Iterable, List, Optional, overload, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
overload,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
@ -23,20 +34,6 @@ if torch._running_with_deploy():
|
||||
import optree
|
||||
from optree import PyTreeSpec # direct import for type annotations
|
||||
|
||||
from .typing import (
|
||||
Context,
|
||||
DumpableContext,
|
||||
FlattenFunc,
|
||||
FromDumpableContextFn,
|
||||
PyTree,
|
||||
R,
|
||||
S,
|
||||
T,
|
||||
ToDumpableContextFn,
|
||||
U,
|
||||
UnflattenFunc,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTree",
|
||||
@ -67,8 +64,21 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
U = TypeVar("U")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
Context = Optional[Any]
|
||||
PyTree = Any
|
||||
TreeSpec = PyTreeSpec
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
|
||||
UnflattenFunc = Callable[[Iterable, Context], PyTree]
|
||||
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
|
||||
|
||||
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
@ -213,7 +223,7 @@ def register_pytree_node(
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
from . import python
|
||||
from . import _pytree as python
|
||||
|
||||
python._private_register_pytree_node(
|
||||
cls,
|
||||
@ -883,7 +893,7 @@ def treespec_dumps(treespec: TreeSpec) -> str:
|
||||
f"treespec_dumps(spec): Expected `spec` to be instance of "
|
||||
f"TreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
from .python import (
|
||||
from ._pytree import (
|
||||
tree_structure as _tree_structure,
|
||||
treespec_dumps as _treespec_dumps,
|
||||
)
|
||||
@ -894,7 +904,7 @@ def treespec_dumps(treespec: TreeSpec) -> str:
|
||||
|
||||
def treespec_loads(serialized: str) -> TreeSpec:
|
||||
"""Deserialize a treespec from a JSON string."""
|
||||
from .python import (
|
||||
from ._pytree import (
|
||||
tree_unflatten as _tree_unflatten,
|
||||
treespec_loads as _treespec_loads,
|
||||
)
|
@ -32,23 +32,10 @@ from typing import (
|
||||
overload,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from .typing import (
|
||||
Context,
|
||||
DumpableContext,
|
||||
FlattenFunc,
|
||||
FromDumpableContextFn,
|
||||
PyTree,
|
||||
R,
|
||||
S,
|
||||
T,
|
||||
ToDumpableContextFn,
|
||||
U,
|
||||
UnflattenFunc,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTree",
|
||||
@ -79,10 +66,22 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
U = TypeVar("U")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
|
||||
NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
|
||||
|
||||
|
||||
Context = Any
|
||||
PyTree = Any
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
|
||||
UnflattenFunc = Callable[[Iterable, Context], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
ToStrFunc = Callable[["TreeSpec", List[str]], str]
|
||||
MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]]
|
||||
|
||||
@ -163,7 +162,7 @@ def _register_pytree_node(
|
||||
)
|
||||
|
||||
try:
|
||||
from . import cxx
|
||||
from . import _cxx_pytree as cxx
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
@ -1,88 +0,0 @@
|
||||
"""
|
||||
Contains utility functions for working with nested python data structures.
|
||||
|
||||
A *pytree* is Python nested data structure. It is a tree in the sense that
|
||||
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
|
||||
Python values. Furthermore, a pytree should not contain reference cycles.
|
||||
|
||||
pytrees are useful for working with nested collections of Tensors. For example,
|
||||
one can use `tree_map` to map a function over all Tensors inside some nested
|
||||
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
|
||||
inside some nested collection. pytrees are helpful for implementing nested
|
||||
collection support for PyTorch APIs.
|
||||
"""
|
||||
|
||||
from .api import (
|
||||
Context,
|
||||
DumpableContext,
|
||||
FlattenFunc,
|
||||
FromDumpableContextFn,
|
||||
LeafSpec,
|
||||
PyTree,
|
||||
register_pytree_node,
|
||||
ToDumpableContextFn,
|
||||
tree_all,
|
||||
tree_all_only,
|
||||
tree_any,
|
||||
tree_any_only,
|
||||
tree_flatten,
|
||||
tree_leaves,
|
||||
tree_map,
|
||||
tree_map_,
|
||||
tree_map_only,
|
||||
tree_map_only_,
|
||||
tree_structure,
|
||||
tree_unflatten,
|
||||
TreeSpec,
|
||||
treespec_dumps,
|
||||
treespec_loads,
|
||||
treespec_pprint,
|
||||
UnflattenFunc,
|
||||
)
|
||||
from .api.python import ( # used by internals and/or third-party packages
|
||||
_broadcast_to_and_flatten,
|
||||
_dict_flatten,
|
||||
_dict_unflatten,
|
||||
_get_node_type,
|
||||
_is_leaf,
|
||||
_list_flatten,
|
||||
_list_unflatten,
|
||||
_namedtuple_flatten,
|
||||
_namedtuple_unflatten,
|
||||
_odict_flatten,
|
||||
_odict_unflatten,
|
||||
_register_pytree_node,
|
||||
_tuple_flatten,
|
||||
_tuple_unflatten,
|
||||
arg_tree_leaves,
|
||||
SUPPORTED_NODES,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTree",
|
||||
"Context",
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"register_pytree_node",
|
||||
"tree_flatten",
|
||||
"tree_unflatten",
|
||||
"tree_leaves",
|
||||
"tree_structure",
|
||||
"tree_map",
|
||||
"tree_map_",
|
||||
"tree_map_only",
|
||||
"tree_map_only_",
|
||||
"tree_all",
|
||||
"tree_any",
|
||||
"tree_all_only",
|
||||
"tree_any_only",
|
||||
"treespec_dumps",
|
||||
"treespec_loads",
|
||||
"treespec_pprint",
|
||||
]
|
@ -1,72 +0,0 @@
|
||||
"""
|
||||
Contains utility functions for working with nested python data structures.
|
||||
|
||||
A *pytree* is Python nested data structure. It is a tree in the sense that
|
||||
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
|
||||
Python values. Furthermore, a pytree should not contain reference cycles.
|
||||
|
||||
pytrees are useful for working with nested collections of Tensors. For example,
|
||||
one can use `tree_map` to map a function over all Tensors inside some nested
|
||||
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
|
||||
inside some nested collection. pytrees are helpful for implementing nested
|
||||
collection support for PyTorch APIs.
|
||||
"""
|
||||
|
||||
from .python import (
|
||||
LeafSpec,
|
||||
register_pytree_node,
|
||||
tree_all,
|
||||
tree_all_only,
|
||||
tree_any,
|
||||
tree_any_only,
|
||||
tree_flatten,
|
||||
tree_leaves,
|
||||
tree_map,
|
||||
tree_map_,
|
||||
tree_map_only,
|
||||
tree_map_only_,
|
||||
tree_structure,
|
||||
tree_unflatten,
|
||||
TreeSpec,
|
||||
treespec_dumps,
|
||||
treespec_loads,
|
||||
treespec_pprint,
|
||||
)
|
||||
from .typing import (
|
||||
Context,
|
||||
DumpableContext,
|
||||
FlattenFunc,
|
||||
FromDumpableContextFn,
|
||||
PyTree,
|
||||
ToDumpableContextFn,
|
||||
UnflattenFunc,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTree",
|
||||
"Context",
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"register_pytree_node",
|
||||
"tree_flatten",
|
||||
"tree_unflatten",
|
||||
"tree_leaves",
|
||||
"tree_structure",
|
||||
"tree_map",
|
||||
"tree_map_",
|
||||
"tree_map_only",
|
||||
"tree_map_only_",
|
||||
"tree_all",
|
||||
"tree_any",
|
||||
"tree_all_only",
|
||||
"tree_any_only",
|
||||
"treespec_dumps",
|
||||
"treespec_loads",
|
||||
"treespec_pprint",
|
||||
]
|
@ -1,29 +0,0 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Iterable, List, Tuple, TypeVar
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Context",
|
||||
"PyTree",
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
U = TypeVar("U")
|
||||
R = TypeVar("R")
|
||||
|
||||
Context = Any
|
||||
PyTree = Any
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
|
||||
UnflattenFunc = Callable[[Iterable, Context], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
Reference in New Issue
Block a user