Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"

This reverts commit eeeb40b32717bab75bd7d8f28f8343385688b3ab.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to Reverting this pr as the one under it in the stack is causing regressions in torchrec ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1806044435))
This commit is contained in:
PyTorch MergeBot
2023-11-10 16:30:36 +00:00
parent d4c810cc11
commit 23e0923c74
10 changed files with 53 additions and 233 deletions

View File

@ -22,7 +22,7 @@ KEEP_ELLIPSES = 2
KEEP_NAME_AND_ELLIPSES = 3
PRUNE_FUNCTIONS = {
"torch/utils/_pytree/api/python.py(...): tree_map": KEEP_NAME_AND_ELLIPSES,
"torch/utils/_pytree.py(...): tree_map": KEEP_NAME_AND_ELLIPSES,
"torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
"torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
"torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
@ -699,14 +699,14 @@ class TestProfilerTree(TestCase):
...
aten::add
test_profiler_tree.py(...): __torch_dispatch__
torch/utils/_pytree/api/python.py(...): tree_map
torch/utils/_pytree.py(...): tree_map
...
torch/utils/_pytree/api/python.py(...): tree_map
torch/utils/_pytree.py(...): tree_map
...
torch/_ops.py(...): __call__
<built-in method of PyCapsule object at 0xXXXXXXXXXXXX>
aten::add
torch/utils/_pytree/api/python.py(...): tree_map
torch/utils/_pytree.py(...): tree_map
...
torch/profiler/profiler.py(...): __exit__
torch/profiler/profiler.py(...): stop

View File

@ -4,8 +4,8 @@ import unittest
from collections import namedtuple, OrderedDict
import torch
import torch.utils._pytree.api.cxx as cxx_pytree
import torch.utils._pytree.api.python as py_pytree
import torch.utils._cxx_pytree as cxx_pytree
import torch.utils._pytree as py_pytree
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,

View File

@ -12,7 +12,7 @@ from torch.fx.experimental.proxy_tensor import (
from torch._custom_ops import impl_abstract
try:
from torch.utils._pytree.api.cxx import tree_map_only
from torch.utils._cxx_pytree import tree_map_only
except ImportError:
from torch.utils._pytree import tree_map_only # type: ignore[no-redef]

View File

@ -23,9 +23,9 @@ from torch.distributed._tensor.redistribute import redistribute_local_tensor
from torch.distributed._tensor.sharding_prop import ShardingPropagator
try:
from torch.utils._pytree.api import cxx as pytree
from torch.utils import _cxx_pytree as pytree
except ImportError:
from torch.utils._pytree.api import python as pytree # type: ignore[no-redef]
from torch.utils import _pytree as pytree # type: ignore[no-redef]
aten = torch.ops.aten

View File

@ -7,7 +7,7 @@ from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec
try:
from torch.utils._pytree.api.cxx import tree_map_only, TreeSpec
from torch.utils._cxx_pytree import tree_map_only, TreeSpec
except ImportError:
from torch.utils._pytree import ( # type: ignore[no-redef, assignment]
tree_map_only,

View File

@ -13,7 +13,18 @@ collection support for PyTorch APIs.
"""
import functools
from typing import Any, Callable, Iterable, List, Optional, overload, Tuple, Type, Union
from typing import (
Any,
Callable,
Iterable,
List,
Optional,
overload,
Tuple,
Type,
TypeVar,
Union,
)
import torch
@ -23,20 +34,6 @@ if torch._running_with_deploy():
import optree
from optree import PyTreeSpec # direct import for type annotations
from .typing import (
Context,
DumpableContext,
FlattenFunc,
FromDumpableContextFn,
PyTree,
R,
S,
T,
ToDumpableContextFn,
U,
UnflattenFunc,
)
__all__ = [
"PyTree",
@ -67,8 +64,21 @@ __all__ = [
]
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")
Context = Optional[Any]
PyTree = Any
TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[Iterable, Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@ -213,7 +223,7 @@ def register_pytree_node(
namespace=namespace,
)
from . import python
from . import _pytree as python
python._private_register_pytree_node(
cls,
@ -883,7 +893,7 @@ def treespec_dumps(treespec: TreeSpec) -> str:
f"treespec_dumps(spec): Expected `spec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}."
)
from .python import (
from ._pytree import (
tree_structure as _tree_structure,
treespec_dumps as _treespec_dumps,
)
@ -894,7 +904,7 @@ def treespec_dumps(treespec: TreeSpec) -> str:
def treespec_loads(serialized: str) -> TreeSpec:
"""Deserialize a treespec from a JSON string."""
from .python import (
from ._pytree import (
tree_unflatten as _tree_unflatten,
treespec_loads as _treespec_loads,
)

View File

@ -32,23 +32,10 @@ from typing import (
overload,
Tuple,
Type,
TypeVar,
Union,
)
from .typing import (
Context,
DumpableContext,
FlattenFunc,
FromDumpableContextFn,
PyTree,
R,
S,
T,
ToDumpableContextFn,
U,
UnflattenFunc,
)
__all__ = [
"PyTree",
@ -79,10 +66,22 @@ __all__ = [
]
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")
DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[Iterable, Context], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
ToStrFunc = Callable[["TreeSpec", List[str]], str]
MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]]
@ -163,7 +162,7 @@ def _register_pytree_node(
)
try:
from . import cxx
from . import _cxx_pytree as cxx
except ImportError:
pass
else:

View File

@ -1,88 +0,0 @@
"""
Contains utility functions for working with nested python data structures.
A *pytree* is Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""
from .api import (
Context,
DumpableContext,
FlattenFunc,
FromDumpableContextFn,
LeafSpec,
PyTree,
register_pytree_node,
ToDumpableContextFn,
tree_all,
tree_all_only,
tree_any,
tree_any_only,
tree_flatten,
tree_leaves,
tree_map,
tree_map_,
tree_map_only,
tree_map_only_,
tree_structure,
tree_unflatten,
TreeSpec,
treespec_dumps,
treespec_loads,
treespec_pprint,
UnflattenFunc,
)
from .api.python import ( # used by internals and/or third-party packages
_broadcast_to_and_flatten,
_dict_flatten,
_dict_unflatten,
_get_node_type,
_is_leaf,
_list_flatten,
_list_unflatten,
_namedtuple_flatten,
_namedtuple_unflatten,
_odict_flatten,
_odict_unflatten,
_register_pytree_node,
_tuple_flatten,
_tuple_unflatten,
arg_tree_leaves,
SUPPORTED_NODES,
)
__all__ = [
"PyTree",
"Context",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"TreeSpec",
"LeafSpec",
"register_pytree_node",
"tree_flatten",
"tree_unflatten",
"tree_leaves",
"tree_structure",
"tree_map",
"tree_map_",
"tree_map_only",
"tree_map_only_",
"tree_all",
"tree_any",
"tree_all_only",
"tree_any_only",
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
]

View File

@ -1,72 +0,0 @@
"""
Contains utility functions for working with nested python data structures.
A *pytree* is Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""
from .python import (
LeafSpec,
register_pytree_node,
tree_all,
tree_all_only,
tree_any,
tree_any_only,
tree_flatten,
tree_leaves,
tree_map,
tree_map_,
tree_map_only,
tree_map_only_,
tree_structure,
tree_unflatten,
TreeSpec,
treespec_dumps,
treespec_loads,
treespec_pprint,
)
from .typing import (
Context,
DumpableContext,
FlattenFunc,
FromDumpableContextFn,
PyTree,
ToDumpableContextFn,
UnflattenFunc,
)
__all__ = [
"PyTree",
"Context",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"TreeSpec",
"LeafSpec",
"register_pytree_node",
"tree_flatten",
"tree_unflatten",
"tree_leaves",
"tree_structure",
"tree_map",
"tree_map_",
"tree_map_only",
"tree_map_only_",
"tree_all",
"tree_any",
"tree_all_only",
"tree_any_only",
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
]

View File

@ -1,29 +0,0 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Iterable, List, Tuple, TypeVar
__all__ = [
"Context",
"PyTree",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
]
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[Iterable, Context], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]