Compare commits

...

182 Commits

Author SHA1 Message Date
9ebf4d1d17 Update
[ghstack-poisoned]
2025-11-06 01:17:09 +08:00
e6a9af94c9 Update (base update)
[ghstack-poisoned]
2025-11-06 01:17:09 +08:00
017eea3b7f Update
[ghstack-poisoned]
2025-11-04 21:47:46 +08:00
f0aa9ad2ce Update (base update)
[ghstack-poisoned]
2025-11-04 21:47:46 +08:00
ea7ce296b3 Update
[ghstack-poisoned]
2025-11-04 21:33:32 +08:00
d7e0635e61 Update (base update)
[ghstack-poisoned]
2025-11-04 21:33:32 +08:00
cc74263539 Update
[ghstack-poisoned]
2025-11-04 20:18:41 +08:00
e813b773ee Update (base update)
[ghstack-poisoned]
2025-11-04 20:18:41 +08:00
a4d0f125e0 Update
[ghstack-poisoned]
2025-11-04 17:07:29 +08:00
e52ebb0c10 Update (base update)
[ghstack-poisoned]
2025-11-04 17:07:29 +08:00
4bd6eb00c4 Update
[ghstack-poisoned]
2025-11-04 13:21:30 +08:00
2a577c1d91 Update (base update)
[ghstack-poisoned]
2025-11-04 13:02:08 +08:00
3dd66fdc20 Update
[ghstack-poisoned]
2025-11-04 13:02:08 +08:00
c9f3c833ef Update (base update)
[ghstack-poisoned]
2025-11-04 12:53:26 +08:00
e5a26f2430 Update
[ghstack-poisoned]
2025-11-04 12:53:26 +08:00
cd51ee276b Update (base update)
[ghstack-poisoned]
2025-10-29 21:23:30 +08:00
2f75d3aebe Update
[ghstack-poisoned]
2025-10-29 21:23:30 +08:00
95039738e3 Update (base update)
[ghstack-poisoned]
2025-10-29 21:02:54 +08:00
f44a8f5201 Update
[ghstack-poisoned]
2025-10-29 21:02:54 +08:00
d96b71c3ef Update (base update)
[ghstack-poisoned]
2025-10-11 21:35:31 +08:00
3f0985bf89 Update
[ghstack-poisoned]
2025-10-11 21:35:31 +08:00
29d16a10f3 Update (base update)
[ghstack-poisoned]
2025-10-08 22:52:13 +08:00
6676fe538f Update
[ghstack-poisoned]
2025-10-08 22:52:13 +08:00
c374b66c75 Update (base update)
[ghstack-poisoned]
2025-10-08 22:39:24 +08:00
70183368c6 Update
[ghstack-poisoned]
2025-10-08 22:39:24 +08:00
faa50fa6c4 Update (base update)
[ghstack-poisoned]
2025-09-19 18:03:46 +08:00
c601a1ea72 Update
[ghstack-poisoned]
2025-09-19 18:03:46 +08:00
dc53fc2af2 Update (base update)
[ghstack-poisoned]
2025-09-06 11:34:34 +08:00
bbb546f542 Update
[ghstack-poisoned]
2025-09-06 11:34:34 +08:00
26a1088f9f Update (base update)
[ghstack-poisoned]
2025-08-17 16:23:38 +08:00
a2d5216c04 Update
[ghstack-poisoned]
2025-08-17 16:23:38 +08:00
968b72ca2c Update (base update)
[ghstack-poisoned]
2025-08-09 02:51:18 +08:00
9f3385822d Update
[ghstack-poisoned]
2025-08-09 02:51:18 +08:00
54911834c4 Update (base update)
[ghstack-poisoned]
2025-07-31 15:19:09 +08:00
5ae581762b Update
[ghstack-poisoned]
2025-07-31 15:19:09 +08:00
1e1b37ed77 Update (base update)
[ghstack-poisoned]
2025-07-25 20:00:31 +08:00
332e835040 Update
[ghstack-poisoned]
2025-07-25 20:00:31 +08:00
8e2c6ff709 Update (base update)
[ghstack-poisoned]
2025-07-17 15:02:04 +08:00
d05480c236 Update
[ghstack-poisoned]
2025-07-17 15:02:04 +08:00
f589cb4a72 Update (base update)
[ghstack-poisoned]
2025-07-09 19:01:34 +08:00
3e66bd8fa8 Update
[ghstack-poisoned]
2025-07-09 19:01:34 +08:00
7aff2fc214 Update (base update)
[ghstack-poisoned]
2025-07-03 16:24:23 +08:00
369df36d49 Update
[ghstack-poisoned]
2025-07-03 16:24:23 +08:00
41663a247d Update (base update)
[ghstack-poisoned]
2025-06-28 20:59:47 +08:00
b27cc37252 Update
[ghstack-poisoned]
2025-06-28 20:59:47 +08:00
1c337ea84b Update (base update)
[ghstack-poisoned]
2025-06-27 21:27:45 +08:00
80e40a5976 Update
[ghstack-poisoned]
2025-06-27 21:27:45 +08:00
6d00bd774f Update (base update)
[ghstack-poisoned]
2025-06-23 22:51:21 +08:00
fb2b16422d Update
[ghstack-poisoned]
2025-06-23 22:51:21 +08:00
303e7afcd9 Update (base update)
[ghstack-poisoned]
2025-06-18 23:17:48 +08:00
2a9ceb2f1f Update
[ghstack-poisoned]
2025-06-18 23:17:48 +08:00
c2e1972c18 Update (base update)
[ghstack-poisoned]
2025-06-06 19:50:50 +08:00
cead985182 Update
[ghstack-poisoned]
2025-06-06 19:50:50 +08:00
b32f36ce35 Update (base update)
[ghstack-poisoned]
2025-05-31 21:59:59 +08:00
11d7c79cea Update
[ghstack-poisoned]
2025-05-31 21:59:59 +08:00
2c60570864 Update (base update)
[ghstack-poisoned]
2025-05-28 20:43:33 +08:00
979efcb825 Update
[ghstack-poisoned]
2025-05-28 20:43:33 +08:00
faea6584f8 Update (base update)
[ghstack-poisoned]
2025-05-16 11:37:32 +08:00
99836e07fb Update
[ghstack-poisoned]
2025-05-16 11:37:32 +08:00
3d383f42e9 Update (base update)
[ghstack-poisoned]
2025-05-14 20:35:01 +08:00
08d95fe5c4 Update
[ghstack-poisoned]
2025-05-14 20:35:01 +08:00
27ac21550e Update (base update)
[ghstack-poisoned]
2025-05-08 21:19:08 +08:00
e57894677e Update
[ghstack-poisoned]
2025-05-08 21:19:08 +08:00
28c48d60aa Update (base update)
[ghstack-poisoned]
2025-05-04 02:10:44 +08:00
41259dc86e Update
[ghstack-poisoned]
2025-05-04 02:10:44 +08:00
ab474eebfd Update (base update)
[ghstack-poisoned]
2025-05-03 02:34:22 +08:00
113b8306a8 Update
[ghstack-poisoned]
2025-05-03 02:34:22 +08:00
727a1aa849 Update (base update)
[ghstack-poisoned]
2025-05-03 01:14:43 +08:00
dbd47d2dae Update
[ghstack-poisoned]
2025-05-03 01:14:43 +08:00
c659299f60 Update (base update)
[ghstack-poisoned]
2025-05-03 00:45:00 +08:00
ba32ef92a6 Update
[ghstack-poisoned]
2025-05-03 00:45:00 +08:00
4fd958b566 Update (base update)
[ghstack-poisoned]
2025-05-03 00:40:33 +08:00
313bfeea17 Update
[ghstack-poisoned]
2025-05-03 00:40:33 +08:00
f3e5113185 Update (base update)
[ghstack-poisoned]
2025-05-02 02:30:03 +08:00
5710b9d4af Update
[ghstack-poisoned]
2025-05-02 02:30:03 +08:00
e4d59f3a0a Update (base update)
[ghstack-poisoned]
2025-05-02 02:25:06 +08:00
599b39676b Update
[ghstack-poisoned]
2025-05-02 02:25:06 +08:00
cd38ad58a5 Update (base update)
[ghstack-poisoned]
2025-05-02 01:44:37 +08:00
fffb30d445 Update
[ghstack-poisoned]
2025-05-02 01:44:37 +08:00
3e53b1fb27 Update (base update)
[ghstack-poisoned]
2025-05-02 01:39:06 +08:00
c309091454 Update
[ghstack-poisoned]
2025-05-02 01:39:06 +08:00
121f110b83 Update (base update)
[ghstack-poisoned]
2025-04-26 11:34:39 +08:00
bfe4839177 Update
[ghstack-poisoned]
2025-04-26 11:34:39 +08:00
7a962a2f0d Update (base update)
[ghstack-poisoned]
2025-04-23 21:35:46 +08:00
ad0668cbb8 Update
[ghstack-poisoned]
2025-04-23 21:35:46 +08:00
f6795a4922 Update (base update)
[ghstack-poisoned]
2025-04-15 22:19:49 +08:00
0d8d2a0360 Update
[ghstack-poisoned]
2025-04-15 22:19:49 +08:00
f98737b13a Update (base update)
[ghstack-poisoned]
2025-04-15 22:12:31 +08:00
898c2037ed Update
[ghstack-poisoned]
2025-04-15 22:12:31 +08:00
f643f2e78b Update (base update)
[ghstack-poisoned]
2025-04-15 22:10:41 +08:00
97cf576c2c Update
[ghstack-poisoned]
2025-04-15 22:10:41 +08:00
d5e568d6e0 Update (base update)
[ghstack-poisoned]
2025-04-15 22:03:25 +08:00
c9607615a4 Update
[ghstack-poisoned]
2025-04-15 22:03:25 +08:00
2549707053 Update (base update)
[ghstack-poisoned]
2025-04-11 19:04:56 +08:00
aa53182976 Update
[ghstack-poisoned]
2025-04-11 19:04:56 +08:00
3b8b6ef6fb Update (base update)
[ghstack-poisoned]
2025-04-11 18:35:20 +08:00
0625bbf0c7 Update
[ghstack-poisoned]
2025-04-11 18:35:20 +08:00
57f2575735 Update (base update)
[ghstack-poisoned]
2025-04-11 18:27:39 +08:00
ddb82b6b96 Update
[ghstack-poisoned]
2025-04-11 18:27:39 +08:00
d1aba50677 Update (base update)
[ghstack-poisoned]
2025-04-11 18:16:46 +08:00
760f4fb105 Update
[ghstack-poisoned]
2025-04-11 18:16:46 +08:00
46bb41bd37 Update (base update)
[ghstack-poisoned]
2025-04-10 17:25:15 +08:00
98984e1561 Update
[ghstack-poisoned]
2025-04-10 17:25:15 +08:00
d90e80dd35 Update (base update)
[ghstack-poisoned]
2025-04-07 22:41:39 +08:00
246b2fd7a0 Update
[ghstack-poisoned]
2025-04-07 22:41:39 +08:00
8c13a8323a Update (base update)
[ghstack-poisoned]
2025-04-05 23:26:58 +08:00
1fa721545a Update
[ghstack-poisoned]
2025-04-05 23:26:58 +08:00
f649b7bfbd Update (base update)
[ghstack-poisoned]
2025-04-03 23:11:59 +08:00
f0e9ee0bdc Update
[ghstack-poisoned]
2025-04-03 23:11:59 +08:00
45f34b99a9 Update (base update)
[ghstack-poisoned]
2025-04-03 22:22:53 +08:00
6e6343130f Update
[ghstack-poisoned]
2025-04-03 22:22:53 +08:00
53f09a5136 Update (base update)
[ghstack-poisoned]
2025-04-03 21:58:28 +08:00
5df659dcf8 Update
[ghstack-poisoned]
2025-04-03 21:58:28 +08:00
2e6d995297 Update
[ghstack-poisoned]
2025-04-02 00:14:59 +08:00
aa65799ee0 Update (base update)
[ghstack-poisoned]
2025-04-02 00:14:58 +08:00
02a382b7be Update
[ghstack-poisoned]
2025-03-21 00:08:15 +08:00
6abac60294 Update (base update)
[ghstack-poisoned]
2025-03-21 00:08:15 +08:00
4037b4fc22 Update
[ghstack-poisoned]
2025-03-14 12:47:28 +08:00
af5bc4e801 Update (base update)
[ghstack-poisoned]
2025-03-14 12:47:28 +08:00
108d7f193a Update
[ghstack-poisoned]
2025-03-13 04:41:40 +08:00
a8fb34cae5 Update (base update)
[ghstack-poisoned]
2025-03-13 04:41:40 +08:00
7b122690be Update
[ghstack-poisoned]
2025-03-07 18:09:20 +08:00
4caf34ab53 Update (base update)
[ghstack-poisoned]
2025-03-07 18:09:20 +08:00
3409a1e033 Update
[ghstack-poisoned]
2025-03-07 03:57:21 +08:00
66681eea1b Update (base update)
[ghstack-poisoned]
2025-03-07 03:57:21 +08:00
b7838168c3 Update
[ghstack-poisoned]
2025-03-07 03:19:28 +08:00
26b6913b3d Update (base update)
[ghstack-poisoned]
2025-03-07 03:19:28 +08:00
ca192e08bd Update
[ghstack-poisoned]
2025-03-06 21:45:27 +08:00
88a13fdc94 Update (base update)
[ghstack-poisoned]
2025-03-06 21:45:27 +08:00
d099b63055 Update
[ghstack-poisoned]
2025-03-05 20:36:42 +08:00
5fd8eb6fb8 Update (base update)
[ghstack-poisoned]
2025-03-05 20:36:41 +08:00
a4a2c1cffb Update
[ghstack-poisoned]
2025-03-05 20:19:47 +08:00
400df72bda Update (base update)
[ghstack-poisoned]
2025-03-05 20:19:47 +08:00
33f1963cc4 Update
[ghstack-poisoned]
2025-03-05 20:14:48 +08:00
ac68c29b4c Update (base update)
[ghstack-poisoned]
2025-03-05 20:14:48 +08:00
2832a51c54 Update
[ghstack-poisoned]
2025-03-05 01:46:28 +08:00
3b3af30466 Update (base update)
[ghstack-poisoned]
2025-03-05 00:41:45 +08:00
49df7a7617 Update
[ghstack-poisoned]
2025-03-05 00:41:45 +08:00
92176bfdff Update (base update)
[ghstack-poisoned]
2025-03-04 22:37:24 +08:00
2f41576c02 Update
[ghstack-poisoned]
2025-03-04 22:37:24 +08:00
0b4314efe2 Update (base update)
[ghstack-poisoned]
2025-03-04 19:05:38 +08:00
f092c0e8ce Update
[ghstack-poisoned]
2025-03-04 19:05:38 +08:00
0c224957b6 Update (base update)
[ghstack-poisoned]
2025-03-04 17:15:30 +08:00
a912652a09 Update
[ghstack-poisoned]
2025-03-04 17:15:30 +08:00
afb24fc9c1 Update (base update)
[ghstack-poisoned]
2025-03-04 11:43:02 +08:00
7017bcda6e Update
[ghstack-poisoned]
2025-03-04 11:43:02 +08:00
98b5a2fc77 Update (base update)
[ghstack-poisoned]
2025-03-04 04:46:17 +08:00
c2e569992b Update
[ghstack-poisoned]
2025-03-04 04:46:17 +08:00
4cdcd94061 Update (base update)
[ghstack-poisoned]
2025-03-04 03:31:33 +08:00
b2727a655f Update
[ghstack-poisoned]
2025-03-04 03:31:33 +08:00
3ae5f28df9 Update (base update)
[ghstack-poisoned]
2025-03-04 03:09:37 +08:00
130df3a1d6 Update
[ghstack-poisoned]
2025-03-04 03:09:37 +08:00
ffe60c3005 Update (base update)
[ghstack-poisoned]
2025-03-04 02:44:32 +08:00
f03956e5ae Update
[ghstack-poisoned]
2025-03-04 02:44:32 +08:00
782d543bf7 Update (base update)
[ghstack-poisoned]
2025-03-01 21:29:32 +08:00
a5d05d3d4e Update
[ghstack-poisoned]
2025-03-01 21:29:32 +08:00
7915feda28 Update (base update)
[ghstack-poisoned]
2025-03-01 19:17:04 +08:00
7c8aeffc4a Update
[ghstack-poisoned]
2025-03-01 19:17:04 +08:00
0518f254ed Update
[ghstack-poisoned]
2025-03-01 03:09:57 +08:00
ef3adf6eac Update
[ghstack-poisoned]
2025-03-01 02:56:44 +08:00
c7e5b56a7d Update
[ghstack-poisoned]
2025-03-01 01:36:03 +08:00
86001ed575 Update (base update)
[ghstack-poisoned]
2025-03-01 00:37:37 +08:00
b8eadce989 Update
[ghstack-poisoned]
2025-03-01 00:37:37 +08:00
06a7899877 Update (base update)
[ghstack-poisoned]
2025-03-01 00:36:38 +08:00
4b8ceddea7 Update
[ghstack-poisoned]
2025-03-01 00:36:38 +08:00
d55bb26dc5 Update (base update)
[ghstack-poisoned]
2025-03-01 00:24:08 +08:00
c93a280d5a Update
[ghstack-poisoned]
2025-03-01 00:24:08 +08:00
0e45af0a50 Update (base update)
[ghstack-poisoned]
2025-03-01 00:03:44 +08:00
f62063d25c Update
[ghstack-poisoned]
2025-03-01 00:03:44 +08:00
06d1a7fa0b Update (base update)
[ghstack-poisoned]
2025-02-28 23:53:49 +08:00
00a6482df6 Update
[ghstack-poisoned]
2025-02-28 23:53:49 +08:00
1d6e7ada97 Update (base update)
[ghstack-poisoned]
2025-02-28 23:45:41 +08:00
308f6a05ac Update
[ghstack-poisoned]
2025-02-28 23:45:41 +08:00
4e773f0037 Update
[ghstack-poisoned]
2025-02-28 23:39:42 +08:00
672915aece Update (base update)
[ghstack-poisoned]
2025-02-28 22:39:24 +08:00
08affe6664 Update
[ghstack-poisoned]
2025-02-28 22:39:24 +08:00
c36201b8b2 Update
[ghstack-poisoned]
2025-02-28 22:11:12 +08:00
347e4a1001 Update (base update)
[ghstack-poisoned]
2025-02-28 20:16:12 +08:00
c5f436b865 Update
[ghstack-poisoned]
2025-02-28 20:16:12 +08:00
51232f8496 Update
[ghstack-poisoned]
2025-02-28 19:35:17 +08:00
a9a875cb5c Update (base update)
[ghstack-poisoned]
2025-02-28 18:51:42 +08:00
3ae649453b Update
[ghstack-poisoned]
2025-02-28 18:51:42 +08:00
23 changed files with 644 additions and 144 deletions

View File

@ -195,6 +195,7 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/torch/utils/_pytree.py @XuehaiPan
/torch/utils/_cxx_pytree.py @XuehaiPan
/torch/utils/pytree/ @XuehaiPan
/torch/pytree.py @XuehaiPan
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
# Relating to libtorch ABI

View File

@ -59,6 +59,7 @@ torch.special <special>
torch.overrides
torch.nativert <nativert>
torch.package <package>
torch.pytree <pytree>
profiler
nn.init
nn.attention
@ -76,6 +77,7 @@ sparse
storage
torch.testing <testing>
torch.utils <utils>
torch.utils.pytree
torch.utils.benchmark <benchmark_utils>
torch.utils.checkpoint <checkpoint>
torch.utils.cpp_extension <cpp_extension>

7
docs/source/pytree.rst Normal file
View File

@ -0,0 +1,7 @@
torch.pytree
============
.. currentmodule:: torch.pytree
.. automodule:: torch.pytree
:members:

View File

@ -0,0 +1,7 @@
torch.utils.pytree
==================
.. currentmodule:: torch.utils.pytree
.. automodule:: torch.utils.pytree
:members:

View File

@ -29,6 +29,7 @@ files =
benchmarks/instruction_counts,
tools,
torch/profiler/_memory_profiler.py,
torch/utils/pytree/__init__.py,
torch/utils/_pytree.py,
torch/utils/_cxx_pytree.py,
torch/utils/benchmark/utils/common.py,

View File

@ -687,6 +687,28 @@
"kineto_available",
"record_function"
],
"torch.pytree": [
"PyTreeSpec",
"register_node",
"all",
"all_only",
"any",
"any_only",
"flatten",
"iter",
"leaves",
"map",
"map_",
"map_only",
"map_only_",
"structure",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance"
],
"torch.quantization": [
"ABC",
"DeQuantStub",

View File

@ -147,8 +147,8 @@ class GraphModule(torch.nn.Module):
t: "f32[10]" = l_x_ + l_y_
trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec
trace_point_tensor_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_spec
trace_point_tensor_input_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_input_spec
res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
return (res,)
""", # NOQA: B950

View File

@ -40,6 +40,7 @@ import torch._inductor.test_case
import torch.onnx.operators
import torch.utils._pytree as python_pytree
import torch.utils.cpp_extension
import torch.utils.pytree as generic_pytree
from torch import Tensor
from torch._C import FileCheck
from torch._dynamo import allow_in_graph
@ -104,6 +105,7 @@ from torch.testing._internal.jit_utils import JitTestCase
pytree_modules = {
"generic": generic_pytree,
"python": python_pytree,
}
if python_pytree._cxx_pytree_dynamo_traceable:

View File

@ -8245,7 +8245,7 @@ graph():
%to : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%to,), kwargs = {})
%_spec_1 : [num_users=1] = get_attr[target=_spec_1]
%tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
%tree_unflatten : [num_users=1] = call_function[target=torch.utils.pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
return tree_unflatten""",
)

View File

@ -249,11 +249,11 @@ def forward(self, x, y):
_spec_0 = self._spec_0
_spec_1 = self._spec_1
_spec_4 = self._spec_4
tree_flatten = torch.utils._pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
tree_flatten = torch.utils.pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
getitem = tree_flatten[0]; tree_flatten = None
x = getitem[0]
y = getitem[1]; getitem = None
tree_unflatten_1 = torch.utils._pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
tree_unflatten_1 = torch.utils.pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None
getitem_2 = getitem_1[0]
getitem_3 = getitem_1[1]; getitem_1 = None
@ -261,7 +261,7 @@ def forward(self, x, y):
bar = self.bar(foo); foo = None
tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None
getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
tree_unflatten = torch.utils._pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
tree_unflatten = torch.utils.pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
return tree_unflatten""",
)

View File

@ -9,6 +9,7 @@ import sys
import time
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import auto
from typing import Any, NamedTuple, Optional
@ -22,6 +23,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
@ -838,12 +840,13 @@ class TestPythonPytree(TestCase):
script = """
import sys
import torch
import torch.utils._pytree
assert "torch.utils.pytree" in sys.modules
assert "torch.utils._pytree" in sys.modules
if "torch.utils._cxx_pytree" in sys.modules:
raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
if "optree" in sys.modules:
raise RuntimeError("importing torch.utils._pytree should not import optree")
if not torch.utils.pytree.PYTORCH_USE_CXX_PYTREE:
if "torch.utils._cxx_pytree" in sys.modules:
raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
if "optree" in sys.modules:
raise RuntimeError("importing torch.utils._pytree should not import optree")
"""
try:
subprocess.check_output(
@ -1490,8 +1493,23 @@ class TestCxxPytree(TestCase):
if IS_FBCODE:
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
def assertEqualSpecs(
self,
spec1,
spec2,
msg: str | Callable[[str], str] | None = None,
):
if TEST_WITH_TORCHDYNAMO:
# The Dynamo polyfill returns a pure Python class for PyTreeSpec.
# So we compare the type names and reprs instead because the types
# themselves won't be equal.
self.assertEqual(type(spec1).__name__, type(spec2).__name__, msg=msg)
self.assertEqual(repr(spec1), repr(spec2), msg=msg)
else:
self.assertEqual(spec1, spec2, msg=msg)
def test_treespec_equality(self):
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
self.assertEqualSpecs(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
def test_treespec_repr(self):
# Check that it looks sane
@ -1521,16 +1539,11 @@ class TestCxxPytree(TestCase):
],
)
def test_pytree_serialize(self, spec):
self.assertEqual(
spec,
cxx_pytree.tree_structure(
cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec)
),
)
serialized_spec = cxx_pytree.treespec_dumps(spec)
self.assertIsInstance(serialized_spec, str)
self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
self.assertEqualSpecs(roundtrip_spec, spec)
def test_pytree_serialize_namedtuple(self):
python_pytree._register_namedtuple(
@ -1556,7 +1569,7 @@ class TestCxxPytree(TestCase):
spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
serialized_spec = cxx_pytree.treespec_dumps(spec)
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
self.assertEqual(roundtrip_spec, spec)
self.assertEqualSpecs(roundtrip_spec, spec)
class LocalDummyType:
def __init__(self, x, y):
@ -1572,7 +1585,7 @@ class TestCxxPytree(TestCase):
spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
serialized_spec = cxx_pytree.treespec_dumps(spec)
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
self.assertEqual(roundtrip_spec, spec)
self.assertEqualSpecs(roundtrip_spec, spec)
instantiate_parametrized_tests(TestGenericPytree)

View File

@ -2790,6 +2790,7 @@ if TYPE_CHECKING:
_inductor as _inductor,
_subclasses as _subclasses,
onnx as onnx,
pytree as pytree,
)
else:
@ -2799,6 +2800,7 @@ else:
"_export",
# ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
"onnx",
"pytree",
}
def __getattr__(name):

View File

@ -1,3 +1,5 @@
# Owner(s): ["module: pytree"]
"""
Python polyfills for torch.utils.pytree
"""
@ -7,7 +9,6 @@ from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING, TypeVar
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
@ -18,7 +19,7 @@ from ..decorators import substitute_in_graph
if TYPE_CHECKING:
import builtins
from collections.abc import Callable, Iterable, Mapping
from typing_extensions import Self
from typing_extensions import Self, TypeIs
__all__: list[str] = []
@ -352,8 +353,10 @@ if python_pytree._cxx_pytree_dynamo_traceable:
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
def _is_pytreespec_instance(
obj: Any, /
) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec))
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_leaf,
@ -555,10 +558,13 @@ if python_pytree._cxx_pytree_dynamo_traceable:
)
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
if not _is_pytreespec_instance(leaves):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
# Allow passing the PyTreeSpec instance as the first argument
leaves, treespec = treespec, leaves
return treespec.unflatten(leaves)
__all__ += ["tree_unflatten"]

View File

@ -3452,6 +3452,7 @@ MOD_INLINELIST = [
"torch.utils._python_dispatch",
"torch.utils._pytree",
"torch.utils.hooks",
"torch.utils.pytree",
]
assert sorted(set(MOD_INLINELIST)) == MOD_INLINELIST
MOD_INLINELIST = set(MOD_INLINELIST)

View File

@ -39,13 +39,13 @@ from torch.utils._pytree import (
_register_pytree_node,
Context,
FlattenFunc,
FromDumpableContextFn,
FromDumpableContextFunc,
GetAttrKey,
KeyPath,
keystr,
MappingKey,
SequenceKey,
ToDumpableContextFn,
ToDumpableContextFunc,
tree_flatten_with_path,
UnflattenFunc,
)
@ -485,8 +485,8 @@ def register_dataclass_as_pytree_node(
unflatten_fn: Optional[UnflattenFunc] = None,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(cls), (

View File

@ -42,7 +42,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
%foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
%tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils.pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
%getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
%getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})
@ -293,7 +293,7 @@ def _swap_module_helper(
%y : [num_users=1] = placeholder[target=y]
%_spec_0 : [num_users=1] = get_attr[target=_spec_0]
%tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
%tree_unflatten : [num_users=2] = call_function[target=torch.utils.pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
%getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {})

109
torch/pytree.py Normal file
View File

@ -0,0 +1,109 @@
# Owner(s): ["module: pytree"]
"""
Contains utility functions for working with nested python data structures.
A *pytree* is Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `map` to map a function over all Tensors inside some nested
collection of Tensors and `leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""
from __future__ import annotations
from typing import Any as _Any, TYPE_CHECKING as _TYPE_CHECKING
from torch.utils.pytree import (
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
PyTree,
PyTreeSpec,
register_pytree_node as register_node,
tree_all as all,
tree_all_only as all_only,
tree_any as any,
tree_any_only as any_only,
tree_flatten as flatten,
tree_iter as iter,
tree_leaves as leaves,
tree_map as map,
tree_map_ as map_,
tree_map_only as map_only,
tree_map_only_ as map_only_,
tree_structure as structure,
tree_unflatten as _tree_unflatten,
)
if _TYPE_CHECKING:
from collections.abc import Iterable
__all__ = [
"PyTreeSpec",
"register_node",
"flatten",
"unflatten",
"iter",
"leaves",
"structure",
"map",
"map_",
"map_only",
"map_only_",
"all",
"any",
"all_only",
"any_only",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
]
def unflatten(treespec: PyTreeSpec, leaves: Iterable[_Any]) -> PyTree:
"""Reconstruct a pytree from the treespec and the leaves.
The inverse of :func:`flatten`.
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> leaves, treespec = torch.pytree.flatten(tree)
>>> tree == torch.pytree.unflatten(treespec, leaves)
True
.. warning::
This function has a different signature than :func:`torch.utils.pytree.tree_unflatten`.
The ``treespec`` argument comes first to have a better :class:`functools.partial` support:
.. code-block:: python
import functools
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
Args:
treespec (PyTreeSpec): The treespec to reconstruct.
leaves (iterable): The list of leaves to use for reconstruction. The list must match the
number of leaves of the treespec.
Returns:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
# pyrefly: ignore [bad-argument-type]
return _tree_unflatten(leaves, treespec)

View File

@ -11,6 +11,7 @@ from torch.utils import (
data as data,
deterministic as deterministic,
hooks as hooks,
pytree as pytree,
)
from torch.utils.backend_registration import (
generate_methods_for_privateuse1_backend,

View File

@ -1,3 +1,5 @@
# Owner(s): ["module: pytree"]
"""
Contains utility functions for working with nested python data structures.
@ -13,6 +15,7 @@ collection support for PyTorch APIs.
"""
import functools
import sys
import types
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, overload, TypeAlias, TypeVar, Union
@ -21,13 +24,21 @@ from typing_extensions import deprecated, Self, TypeIs
import torch.utils._pytree as python_pytree
from torch.torch_version import TorchVersion as _TorchVersion
from torch.utils._pytree import (
Context,
DumpableContext,
FlattenFunc,
FlattenWithKeysFunc,
FromDumpableContextFunc,
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
KeyEntry,
KeyPath,
PyTree,
ToDumpableContextFunc,
UnflattenFunc,
)
@ -51,8 +62,8 @@ __all__ = [
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"ToDumpableContextFunc",
"FromDumpableContextFunc",
"PyTreeSpec",
"TreeSpec",
"LeafSpec",
@ -100,19 +111,8 @@ S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")
TreeSpec: TypeAlias = PyTreeSpec
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
KeyPath = tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
OpTreeUnflattenFunc: TypeAlias = Callable[[Context, Iterable[Any]], PyTree]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@ -129,8 +129,8 @@ def register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
"""Register a container-like type as pytree node.
@ -197,8 +197,8 @@ def _register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
) -> None:
"""Register a container-like type as pytree node for the C++ pytree only.
@ -248,8 +248,8 @@ def _private_register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
for the C++ pytree only. End-users should use :func:`register_pytree_node`
@ -266,8 +266,21 @@ def _private_register_pytree_node(
)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)
def _is_pytreespec_instance(
obj: Any,
/,
) -> TypeIs[Union[TreeSpec, python_pytree.PyTreeSpec]]:
if isinstance(obj, (TreeSpec, python_pytree.PyTreeSpec)):
return True
if "torch._dynamo.polyfills.pytree" in sys.modules:
# The PyTorch Dynamo pytree module is not always available, so we check if it is loaded.
# If the PyTorch Dynamo pytree module is loaded, we can check if the treespec
# is an instance of the PyTorch Dynamo TreeSpec class.
import torch._dynamo.polyfills.pytree as dynamo_pytree
if isinstance(obj, dynamo_pytree.PyTreeSpec):
return True
return False
def treespec_leaf() -> TreeSpec:
@ -394,7 +407,15 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
if not _is_pytreespec_instance(treespec):
if not _is_pytreespec_instance(leaves):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
# Allow passing the PyTreeSpec instance as the first argument
leaves, treespec = treespec, leaves
return treespec.unflatten(leaves)
def tree_iter(
@ -973,7 +994,7 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
"""Serialize a treespec to a JSON string."""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
f"Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
@ -994,16 +1015,22 @@ def treespec_loads(serialized: str) -> TreeSpec:
return treespec
class _DummyLeaf:
class _Asterisk(str):
__slots__ = ()
def __new__(cls) -> Self:
return super().__new__(cls, "*")
def __repr__(self) -> str:
return "*"
return "*" # no quotes
_asterisk = _Asterisk()
del _Asterisk
def treespec_pprint(treespec: TreeSpec) -> str:
dummy_tree = tree_unflatten(
[_DummyLeaf() for _ in range(treespec.num_leaves)],
treespec,
)
dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec)
return repr(dummy_tree)

View File

@ -1,3 +1,5 @@
# Owner(s): ["module: pytree"]
"""
Contains utility functions for working with nested python data structures.
@ -20,6 +22,7 @@ import functools
import importlib
import importlib.metadata
import json
import sys
import threading
import types
import warnings
@ -36,23 +39,28 @@ from typing import (
Optional,
overload,
Protocol,
TYPE_CHECKING,
TypeAlias,
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple, Self
from typing_extensions import deprecated, NamedTuple, Self, TypeIs
from torch.torch_version import TorchVersion as _TorchVersion
if TYPE_CHECKING:
import torch.utils._cxx_pytree as cxx_pytree
__all__ = [
"PyTree",
"Context",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"ToDumpableContextFunc",
"FromDumpableContextFunc",
"PyTreeSpec",
"TreeSpec",
"LeafSpec",
@ -119,17 +127,21 @@ class EnumEncoder(json.JSONEncoder):
return cast(str, super().default(obj))
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
ToStrFunc = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]]
KeyPath = tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
Context: TypeAlias = Any
PyTree: TypeAlias = Any
FlattenFunc: TypeAlias = Callable[[PyTree], tuple[list[Any], Context]]
UnflattenFunc: TypeAlias = Callable[[Iterable[Any], Context], PyTree]
DumpableContext: TypeAlias = Any # Any json dumpable text
ToDumpableContextFunc: TypeAlias = Callable[[Context], DumpableContext]
FromDumpableContextFunc: TypeAlias = Callable[[DumpableContext], Context]
ToDumpableContextFn: TypeAlias = ToDumpableContextFunc
FromDumpableContextFn: TypeAlias = FromDumpableContextFunc
ToStrFunc: TypeAlias = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc: TypeAlias = Callable[[str], Optional[tuple[Any, Context, str]]]
KeyPath: TypeAlias = tuple[KeyEntry, ...]
FlattenWithKeysFunc: TypeAlias = Callable[
[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]
]
# A NodeDef holds two callables:
@ -162,8 +174,8 @@ SUPPORTED_NODES: dict[type[Any], NodeDef] = {}
class _SerializeNodeDef(NamedTuple):
typ: type[Any]
serialized_type_name: str
to_dumpable_context: Optional[ToDumpableContextFn]
from_dumpable_context: Optional[FromDumpableContextFn]
to_dumpable_context: Optional[ToDumpableContextFunc]
from_dumpable_context: Optional[FromDumpableContextFunc]
SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {}
@ -200,8 +212,8 @@ def register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
"""Register a container-like type as pytree node.
@ -250,9 +262,9 @@ def register_pytree_node(
return
if _cxx_pytree_imported:
from . import _cxx_pytree as cxx
import torch.utils._cxx_pytree as cxx_pytree
cxx._private_register_pytree_node(
cxx_pytree._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
@ -528,8 +540,8 @@ def _register_pytree_node(
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
"""Register a container-like type as pytree node for the Python pytree only.
@ -595,8 +607,8 @@ def _private_register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
@ -1086,7 +1098,9 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
# num_children: the number of children of the root Node (i.e., len(children()))
# is_leaf(): whether the root Node is a leaf
@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False)
class TreeSpec:
class PyTreeSpec:
"""Representing the structure of the pytree."""
type: Any
_context: Context
_children: list[Self]
@ -1165,21 +1179,26 @@ class TreeSpec:
return self._children
def is_leaf(self) -> bool:
"""Test whether the treespec represents a leaf."""
return self.num_nodes == 1 and self.num_leaves == 1
def children(self) -> list[Self]:
"""Get all the child treespecs."""
return self._children.copy()
def child(self, index: int) -> Self:
"""Get the child treespec at the given index."""
return self._children[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
"""Flatten the subtrees in ``tree`` up to the structure of this treespec and return a list of subtrees."""
def helper(treespec: TreeSpec, node: PyTree, subtrees: list[PyTree]) -> None:
if treespec.is_leaf():
subtrees.append(tree)
subtrees.append(node)
return
node_type = _get_node_type(tree)
node_type = _get_node_type(node)
if treespec.type not in BUILTIN_TYPES:
# Always require custom node types to match exactly
if node_type != treespec.type:
@ -1188,7 +1207,7 @@ class TreeSpec:
f"expected {treespec.type!r}, but got {node_type!r}.",
)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(tree)
children, context = flatten_fn(node)
if len(children) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
@ -1210,10 +1229,10 @@ class TreeSpec:
f"Node type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
if len(tree) != treespec.num_children:
if len(node) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(tree)}.",
f"expected {treespec.num_children}, but got {len(node)}.",
)
if both_standard_dict:
@ -1225,7 +1244,7 @@ class TreeSpec:
else treespec._context[1]
)
expected_keys = dict_context
got_key_set = set(tree)
got_key_set = set(node)
expected_key_set = set(expected_keys)
if got_key_set != expected_key_set:
missing_keys = expected_key_set.difference(got_key_set)
@ -1236,11 +1255,11 @@ class TreeSpec:
if extra_keys:
message += f"; extra key(s): {extra_keys}"
raise ValueError(f"Node keys mismatch{message}.")
children = [tree[key] for key in expected_keys]
children = [node[key] for key in expected_keys]
else:
# node_type is treespec.type
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(tree)
children, context = flatten_fn(node)
if (
node_type is not deque # ignore mismatch of `maxlen` for deque
) and context != treespec._context:
@ -1257,6 +1276,7 @@ class TreeSpec:
return subtrees
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
"""Reconstruct a pytree from the leaves."""
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
if len(leaves) != self.num_leaves:
@ -1302,7 +1322,7 @@ class TreeSpec:
return hash((node_type, hashable_context, tuple(self._children)))
PyTreeSpec: TypeAlias = TreeSpec
TreeSpec: TypeAlias = PyTreeSpec
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
@ -1364,6 +1384,45 @@ def treespec_dict(
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
def _is_pytreespec_instance(
obj: Any,
) -> TypeIs[Union[TreeSpec, "cxx_pytree.PyTreeSpec"]]:
if isinstance(obj, TreeSpec):
return True
if "torch.utils._cxx_pytree" in sys.modules:
# The C++ pytree module is not always available, so we check if it is loaded.
# If the C++ pytree module is loaded, we can check if the treespec
# is an instance of the C++ TreeSpec class.
import torch.utils._cxx_pytree as cxx_pytree
if isinstance(obj, cxx_pytree.PyTreeSpec):
return True
if "torch._dynamo.polyfills.pytree" in sys.modules:
# The PyTorch Dynamo pytree module is not always available, so we check if it is loaded.
# If the PyTorch Dynamo pytree module is loaded, we can check if the treespec
# is an instance of the PyTorch Dynamo TreeSpec class.
import torch._dynamo.polyfills.pytree as dynamo_pytree
if isinstance(obj, dynamo_pytree.PyTreeSpec):
return True
return False
def _ensure_python_treespec_instance(
treespec: Union[TreeSpec, "cxx_pytree.PyTreeSpec"],
) -> TreeSpec:
if isinstance(treespec, TreeSpec):
return treespec
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
dummy_tree = treespec.unflatten([0] * treespec.num_leaves)
return tree_structure(dummy_tree)
def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1394,11 +1453,14 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
"""Given a list of values and a TreeSpec, builds a pytree.
This is the inverse operation of `tree_flatten`.
"""
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
f"instance of TreeSpec but got item of type {type(treespec)}.",
)
if not _is_pytreespec_instance(treespec):
if not _is_pytreespec_instance(leaves):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
# Allow passing the PyTreeSpec instance as the first argument
leaves, treespec = treespec, leaves
return treespec.unflatten(leaves)
@ -1828,34 +1890,30 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
if not isinstance(treespec, TreeSpec):
raise AssertionError("treespec must be a TreeSpec")
def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> list[Any]:
result: list[Any] = []
if tree_is_leaf(tree, is_leaf=is_leaf):
return [tree] * treespec.num_leaves
if treespec.is_leaf():
def add_leaves(x: Any, subtree: PyTree) -> None:
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
result.extend([x] * subtreespec.num_leaves)
tree_map_(
add_leaves,
prefix_tree,
full_tree,
is_leaf=is_leaf,
)
return result
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
except ValueError:
return None
node_type = _get_node_type(tree)
if node_type != treespec.type:
return None
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)
# Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or context != treespec._context:
return None
# Recursively flatten the children
result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec._children, strict=True):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None:
result += flat
else:
return None
return result
@dataclasses.dataclass
@ -1969,11 +2027,7 @@ _SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}.",
)
treespec = _ensure_python_treespec_instance(treespec)
if protocol is None:
protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL
@ -2002,16 +2056,22 @@ def treespec_loads(serialized: str) -> TreeSpec:
)
class _DummyLeaf:
class _Asterisk(str):
__slots__ = ()
def __new__(cls) -> Self:
return super().__new__(cls, "*")
def __repr__(self) -> str:
return "*"
return "*" # no quotes
_asterisk = _Asterisk()
del _Asterisk
def treespec_pprint(treespec: TreeSpec) -> str:
dummy_tree = tree_unflatten(
[_DummyLeaf() for _ in range(treespec.num_leaves)],
treespec,
)
dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec)
return repr(dummy_tree)

View File

@ -0,0 +1,216 @@
# Owner(s): ["module: pytree"]
"""
Contains utility functions for working with nested python data structures.
A *pytree* is Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""
import os as _os
import sys as _sys
from typing import Any as _Any, Optional as _Optional
import torch.utils._pytree as python
from torch.utils._exposed_in import exposed_in as _exposed_in
from torch.utils._pytree import ( # these type aliases are identical in both implementations
FlattenFunc,
FlattenWithKeysFunc,
FromDumpableContextFunc,
PyTree,
ToDumpableContextFunc,
UnflattenFunc,
)
__all__ = [
"PyTreeSpec",
"register_pytree_node",
"tree_flatten",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_structure",
"tree_map",
"tree_map_",
"tree_map_only",
"tree_map_only_",
"tree_all",
"tree_any",
"tree_all_only",
"tree_any_only",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
]
# NB: Once this variable is read from the environment, the underlying pytree
# implementation is frozen. It cannot be swapped to another at runtime.
PYTORCH_USE_CXX_PYTREE: bool = _os.getenv("PYTORCH_USE_CXX_PYTREE", "0") not in {
"0",
"",
}
if PYTORCH_USE_CXX_PYTREE:
import torch.utils._cxx_pytree as cxx # noqa: F401
if not python._cxx_pytree_dynamo_traceable:
raise ImportError(
"Cannot import package `optree`. "
"Please install `optree` via `python -m pip install --upgrade optree`. "
"Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`."
)
_sys.modules[f"{__name__}.cxx"] = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment]
if not PYTORCH_USE_CXX_PYTREE:
from torch.utils._pytree import (
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
PyTreeSpec,
register_pytree_node as _register_pytree_node,
tree_all,
tree_all_only,
tree_any,
tree_any_only,
tree_flatten,
tree_iter,
tree_leaves,
tree_map,
tree_map_,
tree_map_only,
tree_map_only_,
tree_structure,
tree_unflatten,
treespec_pprint,
)
PyTreeSpec = _exposed_in(__name__)(PyTreeSpec) # type: ignore[misc]
else:
from torch.utils._cxx_pytree import ( # type: ignore[assignment,no-redef]
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
PyTreeSpec,
register_pytree_node as _register_pytree_node,
tree_all,
tree_all_only,
tree_any,
tree_any_only,
tree_flatten,
tree_iter,
tree_leaves,
tree_map,
tree_map_,
tree_map_only,
tree_map_only_,
tree_structure,
tree_unflatten,
treespec_pprint,
)
# Change `__module__` of reexported public APIs to 'torch.utils.pytree'
__func_names = frozenset(
{
"tree_all",
"tree_all_only",
"tree_any",
"tree_any_only",
"tree_flatten",
"tree_iter",
"tree_leaves",
"tree_map",
"tree_map_",
"tree_map_only",
"tree_map_only_",
"tree_structure",
"tree_unflatten",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
}
)
globals().update(
{
name: _exposed_in(__name__)(member)
for name, member in globals().items()
if name in __func_names
}
)
del __func_names, _exposed_in
def register_pytree_node(
cls: type[_Any],
/,
# intentionally use `*_func` over `*_fn` to match annotations
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
) -> None:
"""Register a container-like type as pytree node.
Args:
cls (type): A Python type to treat as an internal pytree node.
flatten_func (callable): A function to be used during flattening, taking an instance of
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
passed to the ``unflatten_func``.
unflatten_func (callable): A function taking two arguments: the unflattened children, and
the auxiliary data that was returned by ``flatten_func`` and stored in the treespec.
The function should return an instance of ``cls``.
Example::
>>> # xdoctest: +SKIP
>>> from collections import UserList
... class MyList(UserList): pass
>>> # Registry a Python type with lambda functions
... register_pytree_node(
... MyList,
... lambda lst: (list(lst), None),
... lambda children, _: MyList(children),
... )
"""
_register_pytree_node(
cls,
flatten_func,
unflatten_func,
)
def __getattr__(name: str) -> _Any:
if name == "cxx":
# Lazy import
import torch.utils._cxx_pytree as cxx # noqa: F811
_sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
return cxx
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -0,0 +1,8 @@
# Owner(s): ["module: pytree"]
from .._cxx_pytree import * # noqa: F403
from .._cxx_pytree import (
__all__ as __all__,
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
KeyPath as KeyPath,
)

View File

@ -0,0 +1,15 @@
# Owner(s): ["module: pytree"]
from .._pytree import * # noqa: F403
from .._pytree import (
__all__ as __all__,
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
arg_tree_leaves as arg_tree_leaves,
BUILTIN_TYPES as BUILTIN_TYPES,
GetAttrKey as GetAttrKey,
KeyEntry as KeyEntry,
KeyPath as KeyPath,
MappingKey as MappingKey,
SequenceKey as SequenceKey,
SUPPORTED_NODES as SUPPORTED_NODES,
)