Compare commits

...

196 Commits

Author SHA1 Message Date
2928f413d1 Update
[ghstack-poisoned]
2025-11-12 21:29:46 +08:00
ec433fea94 Update (base update)
[ghstack-poisoned]
2025-11-12 21:29:46 +08:00
349c7d1a30 Update
[ghstack-poisoned]
2025-11-10 18:01:37 +08:00
4cdc2a78cb Update (base update)
[ghstack-poisoned]
2025-11-10 18:01:37 +08:00
9db0e1357e Update
[ghstack-poisoned]
2025-11-10 16:21:08 +08:00
b0a56fa61d Update (base update)
[ghstack-poisoned]
2025-11-10 16:21:08 +08:00
327b39230e Update
[ghstack-poisoned]
2025-11-10 13:50:28 +08:00
b03e6a5ec0 Update (base update)
[ghstack-poisoned]
2025-11-10 13:50:28 +08:00
aa682fcf71 Update
[ghstack-poisoned]
2025-11-10 00:07:45 +08:00
87f52a9413 Update (base update)
[ghstack-poisoned]
2025-11-10 00:07:45 +08:00
c7a384502b Update
[ghstack-poisoned]
2025-11-08 19:49:54 +08:00
34d2cb35ff Update (base update)
[ghstack-poisoned]
2025-11-08 19:49:54 +08:00
af6d87a61b Update
[ghstack-poisoned]
2025-11-08 09:20:54 +08:00
b72c5df156 Update (base update)
[ghstack-poisoned]
2025-11-08 09:20:54 +08:00
9ebf4d1d17 Update
[ghstack-poisoned]
2025-11-06 01:17:09 +08:00
e6a9af94c9 Update (base update)
[ghstack-poisoned]
2025-11-06 01:17:09 +08:00
017eea3b7f Update
[ghstack-poisoned]
2025-11-04 21:47:46 +08:00
f0aa9ad2ce Update (base update)
[ghstack-poisoned]
2025-11-04 21:47:46 +08:00
ea7ce296b3 Update
[ghstack-poisoned]
2025-11-04 21:33:32 +08:00
d7e0635e61 Update (base update)
[ghstack-poisoned]
2025-11-04 21:33:32 +08:00
cc74263539 Update
[ghstack-poisoned]
2025-11-04 20:18:41 +08:00
e813b773ee Update (base update)
[ghstack-poisoned]
2025-11-04 20:18:41 +08:00
a4d0f125e0 Update
[ghstack-poisoned]
2025-11-04 17:07:29 +08:00
e52ebb0c10 Update (base update)
[ghstack-poisoned]
2025-11-04 17:07:29 +08:00
4bd6eb00c4 Update
[ghstack-poisoned]
2025-11-04 13:21:30 +08:00
2a577c1d91 Update (base update)
[ghstack-poisoned]
2025-11-04 13:02:08 +08:00
3dd66fdc20 Update
[ghstack-poisoned]
2025-11-04 13:02:08 +08:00
c9f3c833ef Update (base update)
[ghstack-poisoned]
2025-11-04 12:53:26 +08:00
e5a26f2430 Update
[ghstack-poisoned]
2025-11-04 12:53:26 +08:00
cd51ee276b Update (base update)
[ghstack-poisoned]
2025-10-29 21:23:30 +08:00
2f75d3aebe Update
[ghstack-poisoned]
2025-10-29 21:23:30 +08:00
95039738e3 Update (base update)
[ghstack-poisoned]
2025-10-29 21:02:54 +08:00
f44a8f5201 Update
[ghstack-poisoned]
2025-10-29 21:02:54 +08:00
d96b71c3ef Update (base update)
[ghstack-poisoned]
2025-10-11 21:35:31 +08:00
3f0985bf89 Update
[ghstack-poisoned]
2025-10-11 21:35:31 +08:00
29d16a10f3 Update (base update)
[ghstack-poisoned]
2025-10-08 22:52:13 +08:00
6676fe538f Update
[ghstack-poisoned]
2025-10-08 22:52:13 +08:00
c374b66c75 Update (base update)
[ghstack-poisoned]
2025-10-08 22:39:24 +08:00
70183368c6 Update
[ghstack-poisoned]
2025-10-08 22:39:24 +08:00
faa50fa6c4 Update (base update)
[ghstack-poisoned]
2025-09-19 18:03:46 +08:00
c601a1ea72 Update
[ghstack-poisoned]
2025-09-19 18:03:46 +08:00
dc53fc2af2 Update (base update)
[ghstack-poisoned]
2025-09-06 11:34:34 +08:00
bbb546f542 Update
[ghstack-poisoned]
2025-09-06 11:34:34 +08:00
26a1088f9f Update (base update)
[ghstack-poisoned]
2025-08-17 16:23:38 +08:00
a2d5216c04 Update
[ghstack-poisoned]
2025-08-17 16:23:38 +08:00
968b72ca2c Update (base update)
[ghstack-poisoned]
2025-08-09 02:51:18 +08:00
9f3385822d Update
[ghstack-poisoned]
2025-08-09 02:51:18 +08:00
54911834c4 Update (base update)
[ghstack-poisoned]
2025-07-31 15:19:09 +08:00
5ae581762b Update
[ghstack-poisoned]
2025-07-31 15:19:09 +08:00
1e1b37ed77 Update (base update)
[ghstack-poisoned]
2025-07-25 20:00:31 +08:00
332e835040 Update
[ghstack-poisoned]
2025-07-25 20:00:31 +08:00
8e2c6ff709 Update (base update)
[ghstack-poisoned]
2025-07-17 15:02:04 +08:00
d05480c236 Update
[ghstack-poisoned]
2025-07-17 15:02:04 +08:00
f589cb4a72 Update (base update)
[ghstack-poisoned]
2025-07-09 19:01:34 +08:00
3e66bd8fa8 Update
[ghstack-poisoned]
2025-07-09 19:01:34 +08:00
7aff2fc214 Update (base update)
[ghstack-poisoned]
2025-07-03 16:24:23 +08:00
369df36d49 Update
[ghstack-poisoned]
2025-07-03 16:24:23 +08:00
41663a247d Update (base update)
[ghstack-poisoned]
2025-06-28 20:59:47 +08:00
b27cc37252 Update
[ghstack-poisoned]
2025-06-28 20:59:47 +08:00
1c337ea84b Update (base update)
[ghstack-poisoned]
2025-06-27 21:27:45 +08:00
80e40a5976 Update
[ghstack-poisoned]
2025-06-27 21:27:45 +08:00
6d00bd774f Update (base update)
[ghstack-poisoned]
2025-06-23 22:51:21 +08:00
fb2b16422d Update
[ghstack-poisoned]
2025-06-23 22:51:21 +08:00
303e7afcd9 Update (base update)
[ghstack-poisoned]
2025-06-18 23:17:48 +08:00
2a9ceb2f1f Update
[ghstack-poisoned]
2025-06-18 23:17:48 +08:00
c2e1972c18 Update (base update)
[ghstack-poisoned]
2025-06-06 19:50:50 +08:00
cead985182 Update
[ghstack-poisoned]
2025-06-06 19:50:50 +08:00
b32f36ce35 Update (base update)
[ghstack-poisoned]
2025-05-31 21:59:59 +08:00
11d7c79cea Update
[ghstack-poisoned]
2025-05-31 21:59:59 +08:00
2c60570864 Update (base update)
[ghstack-poisoned]
2025-05-28 20:43:33 +08:00
979efcb825 Update
[ghstack-poisoned]
2025-05-28 20:43:33 +08:00
faea6584f8 Update (base update)
[ghstack-poisoned]
2025-05-16 11:37:32 +08:00
99836e07fb Update
[ghstack-poisoned]
2025-05-16 11:37:32 +08:00
3d383f42e9 Update (base update)
[ghstack-poisoned]
2025-05-14 20:35:01 +08:00
08d95fe5c4 Update
[ghstack-poisoned]
2025-05-14 20:35:01 +08:00
27ac21550e Update (base update)
[ghstack-poisoned]
2025-05-08 21:19:08 +08:00
e57894677e Update
[ghstack-poisoned]
2025-05-08 21:19:08 +08:00
28c48d60aa Update (base update)
[ghstack-poisoned]
2025-05-04 02:10:44 +08:00
41259dc86e Update
[ghstack-poisoned]
2025-05-04 02:10:44 +08:00
ab474eebfd Update (base update)
[ghstack-poisoned]
2025-05-03 02:34:22 +08:00
113b8306a8 Update
[ghstack-poisoned]
2025-05-03 02:34:22 +08:00
727a1aa849 Update (base update)
[ghstack-poisoned]
2025-05-03 01:14:43 +08:00
dbd47d2dae Update
[ghstack-poisoned]
2025-05-03 01:14:43 +08:00
c659299f60 Update (base update)
[ghstack-poisoned]
2025-05-03 00:45:00 +08:00
ba32ef92a6 Update
[ghstack-poisoned]
2025-05-03 00:45:00 +08:00
4fd958b566 Update (base update)
[ghstack-poisoned]
2025-05-03 00:40:33 +08:00
313bfeea17 Update
[ghstack-poisoned]
2025-05-03 00:40:33 +08:00
f3e5113185 Update (base update)
[ghstack-poisoned]
2025-05-02 02:30:03 +08:00
5710b9d4af Update
[ghstack-poisoned]
2025-05-02 02:30:03 +08:00
e4d59f3a0a Update (base update)
[ghstack-poisoned]
2025-05-02 02:25:06 +08:00
599b39676b Update
[ghstack-poisoned]
2025-05-02 02:25:06 +08:00
cd38ad58a5 Update (base update)
[ghstack-poisoned]
2025-05-02 01:44:37 +08:00
fffb30d445 Update
[ghstack-poisoned]
2025-05-02 01:44:37 +08:00
3e53b1fb27 Update (base update)
[ghstack-poisoned]
2025-05-02 01:39:06 +08:00
c309091454 Update
[ghstack-poisoned]
2025-05-02 01:39:06 +08:00
121f110b83 Update (base update)
[ghstack-poisoned]
2025-04-26 11:34:39 +08:00
bfe4839177 Update
[ghstack-poisoned]
2025-04-26 11:34:39 +08:00
7a962a2f0d Update (base update)
[ghstack-poisoned]
2025-04-23 21:35:46 +08:00
ad0668cbb8 Update
[ghstack-poisoned]
2025-04-23 21:35:46 +08:00
f6795a4922 Update (base update)
[ghstack-poisoned]
2025-04-15 22:19:49 +08:00
0d8d2a0360 Update
[ghstack-poisoned]
2025-04-15 22:19:49 +08:00
f98737b13a Update (base update)
[ghstack-poisoned]
2025-04-15 22:12:31 +08:00
898c2037ed Update
[ghstack-poisoned]
2025-04-15 22:12:31 +08:00
f643f2e78b Update (base update)
[ghstack-poisoned]
2025-04-15 22:10:41 +08:00
97cf576c2c Update
[ghstack-poisoned]
2025-04-15 22:10:41 +08:00
d5e568d6e0 Update (base update)
[ghstack-poisoned]
2025-04-15 22:03:25 +08:00
c9607615a4 Update
[ghstack-poisoned]
2025-04-15 22:03:25 +08:00
2549707053 Update (base update)
[ghstack-poisoned]
2025-04-11 19:04:56 +08:00
aa53182976 Update
[ghstack-poisoned]
2025-04-11 19:04:56 +08:00
3b8b6ef6fb Update (base update)
[ghstack-poisoned]
2025-04-11 18:35:20 +08:00
0625bbf0c7 Update
[ghstack-poisoned]
2025-04-11 18:35:20 +08:00
57f2575735 Update (base update)
[ghstack-poisoned]
2025-04-11 18:27:39 +08:00
ddb82b6b96 Update
[ghstack-poisoned]
2025-04-11 18:27:39 +08:00
d1aba50677 Update (base update)
[ghstack-poisoned]
2025-04-11 18:16:46 +08:00
760f4fb105 Update
[ghstack-poisoned]
2025-04-11 18:16:46 +08:00
46bb41bd37 Update (base update)
[ghstack-poisoned]
2025-04-10 17:25:15 +08:00
98984e1561 Update
[ghstack-poisoned]
2025-04-10 17:25:15 +08:00
d90e80dd35 Update (base update)
[ghstack-poisoned]
2025-04-07 22:41:39 +08:00
246b2fd7a0 Update
[ghstack-poisoned]
2025-04-07 22:41:39 +08:00
8c13a8323a Update (base update)
[ghstack-poisoned]
2025-04-05 23:26:58 +08:00
1fa721545a Update
[ghstack-poisoned]
2025-04-05 23:26:58 +08:00
f649b7bfbd Update (base update)
[ghstack-poisoned]
2025-04-03 23:11:59 +08:00
f0e9ee0bdc Update
[ghstack-poisoned]
2025-04-03 23:11:59 +08:00
45f34b99a9 Update (base update)
[ghstack-poisoned]
2025-04-03 22:22:53 +08:00
6e6343130f Update
[ghstack-poisoned]
2025-04-03 22:22:53 +08:00
53f09a5136 Update (base update)
[ghstack-poisoned]
2025-04-03 21:58:28 +08:00
5df659dcf8 Update
[ghstack-poisoned]
2025-04-03 21:58:28 +08:00
2e6d995297 Update
[ghstack-poisoned]
2025-04-02 00:14:59 +08:00
aa65799ee0 Update (base update)
[ghstack-poisoned]
2025-04-02 00:14:58 +08:00
02a382b7be Update
[ghstack-poisoned]
2025-03-21 00:08:15 +08:00
6abac60294 Update (base update)
[ghstack-poisoned]
2025-03-21 00:08:15 +08:00
4037b4fc22 Update
[ghstack-poisoned]
2025-03-14 12:47:28 +08:00
af5bc4e801 Update (base update)
[ghstack-poisoned]
2025-03-14 12:47:28 +08:00
108d7f193a Update
[ghstack-poisoned]
2025-03-13 04:41:40 +08:00
a8fb34cae5 Update (base update)
[ghstack-poisoned]
2025-03-13 04:41:40 +08:00
7b122690be Update
[ghstack-poisoned]
2025-03-07 18:09:20 +08:00
4caf34ab53 Update (base update)
[ghstack-poisoned]
2025-03-07 18:09:20 +08:00
3409a1e033 Update
[ghstack-poisoned]
2025-03-07 03:57:21 +08:00
66681eea1b Update (base update)
[ghstack-poisoned]
2025-03-07 03:57:21 +08:00
b7838168c3 Update
[ghstack-poisoned]
2025-03-07 03:19:28 +08:00
26b6913b3d Update (base update)
[ghstack-poisoned]
2025-03-07 03:19:28 +08:00
ca192e08bd Update
[ghstack-poisoned]
2025-03-06 21:45:27 +08:00
88a13fdc94 Update (base update)
[ghstack-poisoned]
2025-03-06 21:45:27 +08:00
d099b63055 Update
[ghstack-poisoned]
2025-03-05 20:36:42 +08:00
5fd8eb6fb8 Update (base update)
[ghstack-poisoned]
2025-03-05 20:36:41 +08:00
a4a2c1cffb Update
[ghstack-poisoned]
2025-03-05 20:19:47 +08:00
400df72bda Update (base update)
[ghstack-poisoned]
2025-03-05 20:19:47 +08:00
33f1963cc4 Update
[ghstack-poisoned]
2025-03-05 20:14:48 +08:00
ac68c29b4c Update (base update)
[ghstack-poisoned]
2025-03-05 20:14:48 +08:00
2832a51c54 Update
[ghstack-poisoned]
2025-03-05 01:46:28 +08:00
3b3af30466 Update (base update)
[ghstack-poisoned]
2025-03-05 00:41:45 +08:00
49df7a7617 Update
[ghstack-poisoned]
2025-03-05 00:41:45 +08:00
92176bfdff Update (base update)
[ghstack-poisoned]
2025-03-04 22:37:24 +08:00
2f41576c02 Update
[ghstack-poisoned]
2025-03-04 22:37:24 +08:00
0b4314efe2 Update (base update)
[ghstack-poisoned]
2025-03-04 19:05:38 +08:00
f092c0e8ce Update
[ghstack-poisoned]
2025-03-04 19:05:38 +08:00
0c224957b6 Update (base update)
[ghstack-poisoned]
2025-03-04 17:15:30 +08:00
a912652a09 Update
[ghstack-poisoned]
2025-03-04 17:15:30 +08:00
afb24fc9c1 Update (base update)
[ghstack-poisoned]
2025-03-04 11:43:02 +08:00
7017bcda6e Update
[ghstack-poisoned]
2025-03-04 11:43:02 +08:00
98b5a2fc77 Update (base update)
[ghstack-poisoned]
2025-03-04 04:46:17 +08:00
c2e569992b Update
[ghstack-poisoned]
2025-03-04 04:46:17 +08:00
4cdcd94061 Update (base update)
[ghstack-poisoned]
2025-03-04 03:31:33 +08:00
b2727a655f Update
[ghstack-poisoned]
2025-03-04 03:31:33 +08:00
3ae5f28df9 Update (base update)
[ghstack-poisoned]
2025-03-04 03:09:37 +08:00
130df3a1d6 Update
[ghstack-poisoned]
2025-03-04 03:09:37 +08:00
ffe60c3005 Update (base update)
[ghstack-poisoned]
2025-03-04 02:44:32 +08:00
f03956e5ae Update
[ghstack-poisoned]
2025-03-04 02:44:32 +08:00
782d543bf7 Update (base update)
[ghstack-poisoned]
2025-03-01 21:29:32 +08:00
a5d05d3d4e Update
[ghstack-poisoned]
2025-03-01 21:29:32 +08:00
7915feda28 Update (base update)
[ghstack-poisoned]
2025-03-01 19:17:04 +08:00
7c8aeffc4a Update
[ghstack-poisoned]
2025-03-01 19:17:04 +08:00
0518f254ed Update
[ghstack-poisoned]
2025-03-01 03:09:57 +08:00
ef3adf6eac Update
[ghstack-poisoned]
2025-03-01 02:56:44 +08:00
c7e5b56a7d Update
[ghstack-poisoned]
2025-03-01 01:36:03 +08:00
86001ed575 Update (base update)
[ghstack-poisoned]
2025-03-01 00:37:37 +08:00
b8eadce989 Update
[ghstack-poisoned]
2025-03-01 00:37:37 +08:00
06a7899877 Update (base update)
[ghstack-poisoned]
2025-03-01 00:36:38 +08:00
4b8ceddea7 Update
[ghstack-poisoned]
2025-03-01 00:36:38 +08:00
d55bb26dc5 Update (base update)
[ghstack-poisoned]
2025-03-01 00:24:08 +08:00
c93a280d5a Update
[ghstack-poisoned]
2025-03-01 00:24:08 +08:00
0e45af0a50 Update (base update)
[ghstack-poisoned]
2025-03-01 00:03:44 +08:00
f62063d25c Update
[ghstack-poisoned]
2025-03-01 00:03:44 +08:00
06d1a7fa0b Update (base update)
[ghstack-poisoned]
2025-02-28 23:53:49 +08:00
00a6482df6 Update
[ghstack-poisoned]
2025-02-28 23:53:49 +08:00
1d6e7ada97 Update (base update)
[ghstack-poisoned]
2025-02-28 23:45:41 +08:00
308f6a05ac Update
[ghstack-poisoned]
2025-02-28 23:45:41 +08:00
4e773f0037 Update
[ghstack-poisoned]
2025-02-28 23:39:42 +08:00
672915aece Update (base update)
[ghstack-poisoned]
2025-02-28 22:39:24 +08:00
08affe6664 Update
[ghstack-poisoned]
2025-02-28 22:39:24 +08:00
c36201b8b2 Update
[ghstack-poisoned]
2025-02-28 22:11:12 +08:00
347e4a1001 Update (base update)
[ghstack-poisoned]
2025-02-28 20:16:12 +08:00
c5f436b865 Update
[ghstack-poisoned]
2025-02-28 20:16:12 +08:00
51232f8496 Update
[ghstack-poisoned]
2025-02-28 19:35:17 +08:00
a9a875cb5c Update (base update)
[ghstack-poisoned]
2025-02-28 18:51:42 +08:00
3ae649453b Update
[ghstack-poisoned]
2025-02-28 18:51:42 +08:00
24 changed files with 720 additions and 133 deletions

View File

@ -195,6 +195,7 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/torch/utils/_pytree.py @XuehaiPan
/torch/utils/_cxx_pytree.py @XuehaiPan
/torch/utils/pytree/ @XuehaiPan
/torch/pytree.py @XuehaiPan
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
# Relating to libtorch ABI

View File

@ -60,6 +60,7 @@ torch.special <special>
torch.overrides
torch.nativert <nativert>
torch.package <package>
torch.pytree <pytree>
profiler
nn.init
nn.attention
@ -77,6 +78,7 @@ sparse
storage
torch.testing <testing>
torch.utils <utils>
torch.utils.pytree
torch.utils.benchmark <benchmark_utils>
torch.utils.checkpoint <checkpoint>
torch.utils.cpp_extension <cpp_extension>

7
docs/source/pytree.rst Normal file
View File

@ -0,0 +1,7 @@
torch.pytree
============
.. currentmodule:: torch.pytree
.. automodule:: torch.pytree
:members:

View File

@ -0,0 +1,7 @@
torch.utils.pytree
==================
.. currentmodule:: torch.utils.pytree
.. automodule:: torch.utils.pytree
:members:

View File

@ -29,6 +29,7 @@ files =
benchmarks/instruction_counts,
tools,
torch/profiler/_memory_profiler.py,
torch/utils/pytree/__init__.py,
torch/utils/_pytree.py,
torch/utils/_cxx_pytree.py,
torch/utils/benchmark/utils/common.py,

View File

@ -687,6 +687,28 @@
"kineto_available",
"record_function"
],
"torch.pytree": [
"PyTreeSpec",
"register_node",
"all",
"all_only",
"any",
"any_only",
"flatten",
"iter",
"leaves",
"map",
"map_",
"map_only",
"map_only_",
"structure",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance"
],
"torch.quantization": [
"ABC",
"DeQuantStub",

View File

@ -147,8 +147,8 @@ class GraphModule(torch.nn.Module):
t: "f32[10]" = l_x_ + l_y_
trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec
trace_point_tensor_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_spec
trace_point_tensor_input_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_input_spec
res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
return (res,)
""", # NOQA: B950

View File

@ -40,6 +40,7 @@ import torch._inductor.test_case
import torch.onnx.operators
import torch.utils._pytree as python_pytree
import torch.utils.cpp_extension
import torch.utils.pytree as generic_pytree
from torch import Tensor
from torch._C import FileCheck
from torch._dynamo import allow_in_graph
@ -104,6 +105,7 @@ from torch.testing._internal.jit_utils import JitTestCase
pytree_modules = {
"generic": generic_pytree,
"python": python_pytree,
}
if python_pytree._cxx_pytree_dynamo_traceable:

View File

@ -8296,7 +8296,7 @@ graph():
%to : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%to,), kwargs = {})
%_spec_1 : [num_users=1] = get_attr[target=_spec_1]
%tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
%tree_unflatten : [num_users=1] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
return tree_unflatten""",
)

View File

@ -249,11 +249,11 @@ def forward(self, x, y):
_spec_0 = self._spec_0
_spec_1 = self._spec_1
_spec_4 = self._spec_4
tree_flatten = torch.utils._pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
tree_flatten = torch.utils.pytree.python.tree_flatten((x_1, y_1)); x_1 = y_1 = None
getitem = tree_flatten[0]; tree_flatten = None
x = getitem[0]
y = getitem[1]; getitem = None
tree_unflatten_1 = torch.utils._pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
tree_unflatten_1 = torch.utils.pytree.python.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None
getitem_2 = getitem_1[0]
getitem_3 = getitem_1[1]; getitem_1 = None
@ -261,7 +261,7 @@ def forward(self, x, y):
bar = self.bar(foo); foo = None
tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None
getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
tree_unflatten = torch.utils._pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
tree_unflatten = torch.utils.pytree.python.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
return tree_unflatten""",
)
@ -324,7 +324,7 @@ def forward(self, x, y):
x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
_spec_0 = self._spec_0
_spec_3 = self._spec_3
tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
tree_unflatten = torch.utils.pytree.python.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
getitem = tree_unflatten[0]; tree_unflatten = None
getitem_1 = getitem[0]
getitem_2 = getitem[1]; getitem = None

View File

@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
@ -52,6 +53,14 @@ class GlobalDummyType:
self.x = x
self.y = y
def __eq__(self, other):
if not isinstance(other, GlobalDummyType):
return NotImplemented
return self.x == other.x and self.y == other.y
def __hash__(self):
return hash((self.x, self.y))
cxx_pytree.register_pytree_node(
GlobalDummyType,
@ -156,6 +165,44 @@ class TestGenericPytree(TestCase):
),
)
@parametrize(
"modulename",
[
subtest("python", name="py"),
*([subtest("cxx", name="cxx")] if not IS_FBCODE else []),
],
)
def test_public_api_import(self, modulename):
for use_cxx_pytree in [None, "", "0", *(["1"] if not IS_FBCODE else [])]:
env = os.environ.copy()
if use_cxx_pytree is not None:
env["PYTORCH_USE_CXX_PYTREE"] = str(use_cxx_pytree)
else:
env.pop("PYTORCH_USE_CXX_PYTREE", None)
for statement in (
f"import torch.utils.pytree.{modulename}",
f"from torch.utils.pytree import {modulename}",
f"from torch.utils.pytree.{modulename} import tree_map",
f"import torch.utils.pytree; torch.utils.pytree.{modulename}",
f"import torch.utils.pytree; torch.utils.pytree.{modulename}.tree_map",
):
try:
subprocess.check_output(
[sys.executable, "-c", statement],
stderr=subprocess.STDOUT,
# On Windows, opening the subprocess with the default CWD makes `import torch`
# fail, so just set CWD to this script's directory
cwd=os.path.dirname(os.path.realpath(__file__)),
env=env,
)
except subprocess.CalledProcessError as e:
self.fail(
msg=(
f"Subprocess exception while attempting to run statement `{statement}`: "
+ e.output.decode("utf-8")
)
)
@parametrize_pytree_module
def test_register_pytree_node(self, pytree):
class MyDict(UserDict):
@ -838,12 +885,13 @@ class TestPythonPytree(TestCase):
script = """
import sys
import torch
import torch.utils._pytree
assert "torch.utils.pytree" in sys.modules
assert "torch.utils._pytree" in sys.modules
if "torch.utils._cxx_pytree" in sys.modules:
raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
if "optree" in sys.modules:
raise RuntimeError("importing torch.utils._pytree should not import optree")
if not torch.utils.pytree.PYTORCH_USE_CXX_PYTREE:
if "torch.utils._cxx_pytree" in sys.modules:
raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
if "optree" in sys.modules:
raise RuntimeError("importing torch.utils._pytree should not import optree")
"""
try:
subprocess.check_output(
@ -1490,6 +1538,25 @@ class TestCxxPytree(TestCase):
if IS_FBCODE:
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
def assertEqual(self, x, y, *args, **kwargs):
x_typename, y_typename = type(x).__name__, type(y).__name__
if not ("treespec" in x_typename.lower() or "treespec" in y_typename.lower()):
super().assertEqual(x, y, *args, **kwargs)
# The Dynamo polyfill returns a polyfilled Python class for C++ PyTreeSpec instead of the
# C++ class. So we compare the type names and reprs instead because the types themselves
# won't be equal.
super().assertEqual(x_typename, y_typename, *args, **kwargs)
if not TEST_WITH_TORCHDYNAMO or type(x) is type(y):
super().assertEqual(x, y, *args, **kwargs)
else:
super().assertEqual(
x.unflatten(range(x.num_leaves)),
y.unflatten(range(y.num_leaves)),
*args,
**kwargs,
)
def test_treespec_equality(self):
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
@ -1530,7 +1597,9 @@ class TestCxxPytree(TestCase):
serialized_spec = cxx_pytree.treespec_dumps(spec)
self.assertIsInstance(serialized_spec, str)
self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
self.assertEqual(roundtrip_spec, spec)
def test_pytree_serialize_namedtuple(self):
python_pytree._register_namedtuple(
@ -1563,6 +1632,14 @@ class TestCxxPytree(TestCase):
self.x = x
self.y = y
def __eq__(self, other):
if not isinstance(other, LocalDummyType):
return NotImplemented
return self.x == other.x and self.y == other.y
def __hash__(self):
return hash((self.x, self.y))
cxx_pytree.register_pytree_node(
LocalDummyType,
lambda dummy: ([dummy.x, dummy.y], None),

View File

@ -2794,6 +2794,7 @@ if TYPE_CHECKING:
_inductor as _inductor,
_subclasses as _subclasses,
onnx as onnx,
pytree as pytree,
)
else:
@ -2803,6 +2804,7 @@ else:
"_export",
# ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
"onnx",
"pytree",
}
def __getattr__(name):

View File

@ -1,3 +1,5 @@
# Owner(s): ["module: pytree"]
"""
Python polyfills for torch.utils.pytree
"""
@ -23,6 +25,7 @@ from optree import (
)
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
from ..decorators import substitute_in_graph
@ -427,8 +430,8 @@ class PyTreeSpec:
return self._unflatten_func(self._metadata, subtrees)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec))
@substitute_in_graph( # type: ignore[arg-type]
@ -698,7 +701,7 @@ def tree_structure(
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec.unflatten(leaves)

View File

@ -3454,6 +3454,7 @@ MOD_INLINELIST = [
"torch.utils._python_dispatch",
"torch.utils._pytree",
"torch.utils.hooks",
"torch.utils.pytree",
]
assert sorted(set(MOD_INLINELIST)) == MOD_INLINELIST
MOD_INLINELIST = set(MOD_INLINELIST)

View File

@ -39,13 +39,13 @@ from torch.utils._pytree import (
_register_pytree_node,
Context,
FlattenFunc,
FromDumpableContextFn,
FromDumpableContextFunc,
GetAttrKey,
KeyPath,
keystr,
MappingKey,
SequenceKey,
ToDumpableContextFn,
ToDumpableContextFunc,
tree_flatten_with_path,
UnflattenFunc,
)
@ -485,8 +485,8 @@ def register_dataclass_as_pytree_node(
unflatten_fn: Optional[UnflattenFunc] = None,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(cls), (

View File

@ -42,7 +42,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
%foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
%tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
%getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
%getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})
@ -293,7 +293,7 @@ def _swap_module_helper(
%y : [num_users=1] = placeholder[target=y]
%_spec_0 : [num_users=1] = get_attr[target=_spec_0]
%tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
%tree_unflatten : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
%getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {})

View File

@ -88,7 +88,7 @@ _register_custom_builtin("NoneType", "NoneType = type(None)", type(None))
_register_custom_builtin("torch", "import torch", torch)
_register_custom_builtin("device", "from torch import device", torch.device)
_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree)
_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree)
_register_custom_builtin("pytree", "import torch.utils.pytree.python as pytree", pytree)
def _is_magic(x: str) -> bool:

109
torch/pytree.py Normal file
View File

@ -0,0 +1,109 @@
# Owner(s): ["module: pytree"]
"""
Contains utility functions for working with nested python data structures.
A *pytree* is Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `map` to map a function over all Tensors inside some nested
collection of Tensors and `leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""
from __future__ import annotations
from typing import Any as _Any, TYPE_CHECKING as _TYPE_CHECKING
from torch.utils.pytree import (
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
PyTree,
PyTreeSpec,
register_pytree_node as register_node,
tree_all as all,
tree_all_only as all_only,
tree_any as any,
tree_any_only as any_only,
tree_flatten as flatten,
tree_iter as iter,
tree_leaves as leaves,
tree_map as map,
tree_map_ as map_,
tree_map_only as map_only,
tree_map_only_ as map_only_,
tree_structure as structure,
tree_unflatten as _tree_unflatten,
)
if _TYPE_CHECKING:
from collections.abc import Iterable
__all__ = [
"PyTreeSpec",
"register_node",
"flatten",
"unflatten",
"iter",
"leaves",
"structure",
"map",
"map_",
"map_only",
"map_only_",
"all",
"any",
"all_only",
"any_only",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
]
def unflatten(treespec: PyTreeSpec, leaves: Iterable[_Any]) -> PyTree:
"""Reconstruct a pytree from the treespec and the leaves.
The inverse of :func:`flatten`.
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> leaves, treespec = torch.pytree.flatten(tree)
>>> tree == torch.pytree.unflatten(treespec, leaves)
True
.. warning::
This function has a different signature than :func:`torch.utils.pytree.tree_unflatten`.
The ``treespec`` argument comes first to have a better :class:`functools.partial` support:
.. code-block:: python
import functools
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
tree1 = unflatten_fn(leaves1)
tree2 = unflatten_fn(leaves2)
Args:
treespec (PyTreeSpec): The treespec to reconstruct.
leaves (iterable): The list of leaves to use for reconstruction. The list must match the
number of leaves of the treespec.
Returns:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
# pyrefly: ignore [bad-argument-type]
return _tree_unflatten(leaves, treespec)

View File

@ -11,6 +11,7 @@ from torch.utils import (
data as data,
deterministic as deterministic,
hooks as hooks,
pytree as pytree,
)
from torch.utils.backend_registration import (
generate_methods_for_privateuse1_backend,

View File

@ -1,3 +1,5 @@
# Owner(s): ["module: pytree"]
"""
Contains utility functions for working with nested python data structures.
@ -13,6 +15,7 @@ collection support for PyTorch APIs.
"""
import functools
import sys
import types
from collections.abc import Callable, Iterable, Mapping
from typing import Any, overload, TypeAlias, TypeVar, Union
@ -21,13 +24,21 @@ from typing_extensions import deprecated, Self, TypeIs
import torch.utils._pytree as python_pytree
from torch.torch_version import TorchVersion as _TorchVersion
from torch.utils._pytree import (
Context,
DumpableContext,
FlattenFunc,
FlattenWithKeysFunc,
FromDumpableContextFunc,
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
KeyEntry,
KeyPath,
PyTree,
ToDumpableContextFunc,
UnflattenFunc,
)
@ -51,8 +62,8 @@ __all__ = [
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"ToDumpableContextFunc",
"FromDumpableContextFunc",
"PyTreeSpec",
"TreeSpec",
"LeafSpec",
@ -88,6 +99,9 @@ __all__ = [
]
__name__ = "torch.utils.pytree.cxx" # sets the __module__ attribute of all functions in this module
# In-tree installation may have VCS-based versioning. Update the previous static version.
python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined]
@ -100,19 +114,8 @@ S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")
TreeSpec: TypeAlias = PyTreeSpec
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
KeyPath = tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
OpTreeUnflattenFunc: TypeAlias = Callable[[Context, Iterable[Any]], PyTree]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@ -129,8 +132,8 @@ def register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
to_dumpable_context: ToDumpableContextFunc | None = None,
from_dumpable_context: FromDumpableContextFunc | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node.
@ -197,8 +200,8 @@ def _register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
to_dumpable_context: ToDumpableContextFunc | None = None,
from_dumpable_context: FromDumpableContextFunc | None = None,
) -> None:
"""Register a container-like type as pytree node for the C++ pytree only.
@ -248,8 +251,8 @@ def _private_register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
to_dumpable_context: ToDumpableContextFunc | None = None,
from_dumpable_context: FromDumpableContextFunc | None = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
for the C++ pytree only. End-users should use :func:`register_pytree_node`
@ -266,8 +269,21 @@ def _private_register_pytree_node(
)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)
def _is_pytreespec_instance(
obj: Any,
/,
) -> TypeIs[Union[TreeSpec, python_pytree.PyTreeSpec]]:
if isinstance(obj, (TreeSpec, python_pytree.PyTreeSpec)):
return True
if "torch._dynamo.polyfills.pytree" in sys.modules:
# The PyTorch Dynamo pytree module is not always available, so we check if it is loaded.
# If the PyTorch Dynamo pytree module is loaded, we can check if the treespec
# is an instance of the PyTorch Dynamo TreeSpec class.
import torch._dynamo.polyfills.pytree as dynamo_pytree
if isinstance(obj, dynamo_pytree.PyTreeSpec):
return True
return False
def treespec_leaf() -> TreeSpec:
@ -394,7 +410,15 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
if not _is_pytreespec_instance(treespec):
if not _is_pytreespec_instance(leaves):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
# Allow passing the PyTreeSpec instance as the first argument
leaves, treespec = treespec, leaves
return treespec.unflatten(leaves)
def tree_iter(
@ -959,8 +983,9 @@ def _broadcast_to_and_flatten(
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any] | None:
if not _is_pytreespec_instance(treespec):
raise AssertionError(
f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}"
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
@ -973,7 +998,7 @@ def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str:
"""Serialize a treespec to a JSON string."""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
@ -994,16 +1019,22 @@ def treespec_loads(serialized: str) -> TreeSpec:
return treespec
class _DummyLeaf:
class _Asterisk(str):
__slots__ = ()
def __new__(cls) -> Self:
return super().__new__(cls, "*")
def __repr__(self) -> str:
return "*"
return "*" # no quotes
_asterisk = _Asterisk()
del _Asterisk
def treespec_pprint(treespec: TreeSpec) -> str:
dummy_tree = tree_unflatten(
[_DummyLeaf() for _ in range(treespec.num_leaves)],
treespec,
)
dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec)
return repr(dummy_tree)

View File

@ -1,3 +1,5 @@
# Owner(s): ["module: pytree"]
"""
Contains utility functions for working with nested python data structures.
@ -20,6 +22,7 @@ import functools
import importlib
import importlib.metadata
import json
import sys
import threading
import types
import warnings
@ -35,23 +38,28 @@ from typing import (
NoReturn,
overload,
Protocol,
TYPE_CHECKING,
TypeAlias,
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple, Self
from typing_extensions import deprecated, NamedTuple, Self, TypeIs
from torch.torch_version import TorchVersion as _TorchVersion
if TYPE_CHECKING:
import torch.utils._cxx_pytree as cxx_pytree
__all__ = [
"PyTree",
"Context",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"ToDumpableContextFunc",
"FromDumpableContextFunc",
"PyTreeSpec",
"TreeSpec",
"LeafSpec",
@ -87,6 +95,9 @@ __all__ = [
]
__name__ = "torch.utils.pytree.python" # sets the __module__ attribute of all functions in this module
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
@ -118,17 +129,21 @@ class EnumEncoder(json.JSONEncoder):
return cast(str, super().default(obj))
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
ToStrFunc = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc = Callable[[str], tuple[Any, Context, str] | None]
KeyPath = tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
Context: TypeAlias = Any
PyTree: TypeAlias = Any
FlattenFunc: TypeAlias = Callable[[PyTree], tuple[list[Any], Context]]
UnflattenFunc: TypeAlias = Callable[[Iterable[Any], Context], PyTree]
DumpableContext: TypeAlias = Any # Any json dumpable text
ToDumpableContextFunc: TypeAlias = Callable[[Context], DumpableContext]
FromDumpableContextFunc: TypeAlias = Callable[[DumpableContext], Context]
ToDumpableContextFn: TypeAlias = ToDumpableContextFunc
FromDumpableContextFn: TypeAlias = FromDumpableContextFunc
ToStrFunc: TypeAlias = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc: TypeAlias = Callable[[str], tuple[Any, Context, str] | None]
KeyPath: TypeAlias = tuple[KeyEntry, ...]
FlattenWithKeysFunc: TypeAlias = Callable[
[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]
]
# A NodeDef holds two callables:
@ -161,8 +176,8 @@ SUPPORTED_NODES: dict[type[Any], NodeDef] = {}
class _SerializeNodeDef(NamedTuple):
typ: type[Any]
serialized_type_name: str
to_dumpable_context: ToDumpableContextFn | None
from_dumpable_context: FromDumpableContextFn | None
to_dumpable_context: ToDumpableContextFunc | None
from_dumpable_context: FromDumpableContextFunc | None
SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {}
@ -199,8 +214,8 @@ def register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
to_dumpable_context: ToDumpableContextFunc | None = None,
from_dumpable_context: FromDumpableContextFunc | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node.
@ -249,9 +264,9 @@ def register_pytree_node(
return
if _cxx_pytree_imported:
from . import _cxx_pytree as cxx
import torch.utils._cxx_pytree as cxx_pytree
cxx._private_register_pytree_node(
cxx_pytree._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
@ -527,8 +542,8 @@ def _register_pytree_node(
maybe_from_str_fn: MaybeFromStrFunc | None = None, # deprecated
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
to_dumpable_context: ToDumpableContextFunc | None = None,
from_dumpable_context: FromDumpableContextFunc | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node for the Python pytree only.
@ -594,8 +609,8 @@ def _private_register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
to_dumpable_context: ToDumpableContextFunc | None = None,
from_dumpable_context: FromDumpableContextFunc | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
@ -1085,7 +1100,9 @@ def _is_leaf(tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None) -> b
# num_children: the number of children of the root Node (i.e., len(children()))
# is_leaf(): whether the root Node is a leaf
@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False)
class TreeSpec:
class PyTreeSpec:
"""Representing the structure of the pytree."""
type: Any
_context: Context
_children: list[Self]
@ -1164,21 +1181,26 @@ class TreeSpec:
return self._children
def is_leaf(self) -> bool:
"""Test whether the treespec represents a leaf."""
return self.num_nodes == 1 and self.num_leaves == 1
def children(self) -> list[Self]:
"""Get all the child treespecs."""
return self._children.copy()
def child(self, index: int) -> Self:
"""Get the child treespec at the given index."""
return self._children[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
"""Flatten the subtrees in ``tree`` up to the structure of this treespec and return a list of subtrees."""
def helper(treespec: TreeSpec, node: PyTree, subtrees: list[PyTree]) -> None:
if treespec.is_leaf():
subtrees.append(tree)
subtrees.append(node)
return
node_type = _get_node_type(tree)
node_type = _get_node_type(node)
if treespec.type not in BUILTIN_TYPES:
# Always require custom node types to match exactly
if node_type != treespec.type:
@ -1187,7 +1209,7 @@ class TreeSpec:
f"expected {treespec.type!r}, but got {node_type!r}.",
)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(tree)
children, context = flatten_fn(node)
if len(children) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
@ -1209,10 +1231,10 @@ class TreeSpec:
f"Node type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
if len(tree) != treespec.num_children:
if len(node) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(tree)}.",
f"expected {treespec.num_children}, but got {len(node)}.",
)
if both_standard_dict:
@ -1224,7 +1246,7 @@ class TreeSpec:
else treespec._context[1]
)
expected_keys = dict_context
got_key_set = set(tree)
got_key_set = set(node)
expected_key_set = set(expected_keys)
if got_key_set != expected_key_set:
missing_keys = expected_key_set.difference(got_key_set)
@ -1235,11 +1257,11 @@ class TreeSpec:
if extra_keys:
message += f"; extra key(s): {extra_keys}"
raise ValueError(f"Node keys mismatch{message}.")
children = [tree[key] for key in expected_keys]
children = [node[key] for key in expected_keys]
else:
# node_type is treespec.type
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(tree)
children, context = flatten_fn(node)
if (
node_type is not deque # ignore mismatch of `maxlen` for deque
) and context != treespec._context:
@ -1256,6 +1278,7 @@ class TreeSpec:
return subtrees
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
"""Reconstruct a pytree from the leaves."""
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
if len(leaves) != self.num_leaves:
@ -1301,7 +1324,7 @@ class TreeSpec:
return hash((node_type, hashable_context, tuple(self._children)))
PyTreeSpec: TypeAlias = TreeSpec
TreeSpec: TypeAlias = PyTreeSpec
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
@ -1363,6 +1386,45 @@ def treespec_dict(
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
def _is_pytreespec_instance(
obj: Any,
) -> TypeIs[Union[TreeSpec, "cxx_pytree.PyTreeSpec"]]:
if isinstance(obj, TreeSpec):
return True
if "torch.utils._cxx_pytree" in sys.modules:
# The C++ pytree module is not always available, so we check if it is loaded.
# If the C++ pytree module is loaded, we can check if the treespec
# is an instance of the C++ TreeSpec class.
import torch.utils._cxx_pytree as cxx_pytree
if isinstance(obj, cxx_pytree.PyTreeSpec):
return True
if "torch._dynamo.polyfills.pytree" in sys.modules:
# The PyTorch Dynamo pytree module is not always available, so we check if it is loaded.
# If the PyTorch Dynamo pytree module is loaded, we can check if the treespec
# is an instance of the PyTorch Dynamo TreeSpec class.
import torch._dynamo.polyfills.pytree as dynamo_pytree
if isinstance(obj, dynamo_pytree.PyTreeSpec):
return True
return False
def _ensure_python_treespec_instance(
treespec: Union[TreeSpec, "cxx_pytree.PyTreeSpec"],
) -> TreeSpec:
if isinstance(treespec, TreeSpec):
return treespec
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
dummy_tree = treespec.unflatten([0] * treespec.num_leaves)
return tree_structure(dummy_tree)
def tree_flatten(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
@ -1393,11 +1455,14 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
"""Given a list of values and a TreeSpec, builds a pytree.
This is the inverse operation of `tree_flatten`.
"""
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
f"instance of TreeSpec but got item of type {type(treespec)}.",
)
if not _is_pytreespec_instance(treespec):
if not _is_pytreespec_instance(leaves):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
# Allow passing the PyTreeSpec instance as the first argument
leaves, treespec = treespec, leaves
return treespec.unflatten(leaves)
@ -1827,34 +1892,30 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any] | None:
if not isinstance(treespec, TreeSpec):
raise AssertionError("treespec must be a TreeSpec")
def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
result: list[Any] = []
if tree_is_leaf(tree, is_leaf=is_leaf):
return [tree] * treespec.num_leaves
if treespec.is_leaf():
def add_leaves(x: Any, subtree: PyTree) -> None:
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
result.extend([x] * subtreespec.num_leaves)
tree_map_(
add_leaves,
prefix_tree,
full_tree,
is_leaf=is_leaf,
)
return result
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
except ValueError:
return None
node_type = _get_node_type(tree)
if node_type != treespec.type:
return None
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)
# Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or context != treespec._context:
return None
# Recursively flatten the children
result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec._children, strict=True):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None:
result += flat
else:
return None
return result
@dataclasses.dataclass
@ -1968,11 +2029,7 @@ _SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str:
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}.",
)
treespec = _ensure_python_treespec_instance(treespec)
if protocol is None:
protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL
@ -2001,16 +2058,22 @@ def treespec_loads(serialized: str) -> TreeSpec:
)
class _DummyLeaf:
class _Asterisk(str):
__slots__ = ()
def __new__(cls) -> Self:
return super().__new__(cls, "*")
def __repr__(self) -> str:
return "*"
return "*" # no quotes
_asterisk = _Asterisk()
del _Asterisk
def treespec_pprint(treespec: TreeSpec) -> str:
dummy_tree = tree_unflatten(
[_DummyLeaf() for _ in range(treespec.num_leaves)],
treespec,
)
dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec)
return repr(dummy_tree)

View File

@ -0,0 +1,235 @@
# Owner(s): ["module: pytree"]
"""
Contains utility functions for working with nested python data structures.
A *pytree* is Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""
import os as _os
import sys as _sys
from types import ModuleType as _ModuleType
from typing import Any as _Any, Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING
import torch.utils._pytree as python
from torch.utils._pytree import ( # these type aliases are identical in both implementations
FlattenFunc,
FlattenWithKeysFunc,
FromDumpableContextFunc,
PyTree,
ToDumpableContextFunc,
UnflattenFunc,
)
if _TYPE_CHECKING:
import torch.utils._cxx_pytree as cxx
__all__ = [
"PyTreeSpec",
"register_pytree_node",
"tree_flatten",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_structure",
"tree_map",
"tree_map_",
"tree_map_only",
"tree_map_only_",
"tree_all",
"tree_any",
"tree_all_only",
"tree_any_only",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
]
# NB: Once this variable is read from the environment, the underlying pytree
# implementation is frozen. It cannot be swapped to another at runtime.
PYTORCH_USE_CXX_PYTREE: bool = _os.getenv("PYTORCH_USE_CXX_PYTREE", "0") not in {
"0",
"",
}
def _import_cxx_pytree_and_store() -> _ModuleType:
if not python._cxx_pytree_dynamo_traceable:
raise ImportError(
"Cannot import package `optree`. "
"Please install `optree` via `python -m pip install --upgrade optree`. "
"Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`."
)
import torch.utils._cxx_pytree as cxx
# This allows the following statements to work properly:
#
# import torch.utils.pytree
#
# torch.utils.pytree.cxx
# torch.utils.pytree.cxx.tree_map
#
_sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
return cxx
if PYTORCH_USE_CXX_PYTREE:
cxx = _import_cxx_pytree_and_store() # noqa: F811
else:
cxx = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment]
_sys.modules[f"{__name__}.python"] = python
if cxx is not None:
_sys.modules[f"{__name__}.cxx"] = cxx
else:
del cxx
class LazyCxxModule(_ModuleType):
def __getattr__(self, name: str) -> _Any:
if name == "__name__":
return f"{__name__}.cxx"
if name == "__file__":
return python.__file__.removesuffix("_python.py") + "_cxx_pytree.py"
cxx = globals().get("cxx")
if cxx is None:
if name.startswith("_"):
raise AttributeError(
f"module {self.__name__!r} has not been imported yet: "
f"accessing attribute {name!r}. "
f"Please import {self.__name__!r} explicitly first."
)
# Lazy import on first member access
cxx = _import_cxx_pytree_and_store()
return getattr(cxx, name)
def __setattr__(self, name: str, value: _Any) -> None:
# Lazy import
cxx = _import_cxx_pytree_and_store()
return setattr(cxx, name, value)
# This allows the following statements to work properly:
#
# import torch.utils.pytree.cxx
# from torch.utils.pytree.cxx import tree_map
#
_sys.modules[f"{__name__}.cxx"] = LazyCxxModule(f"{__name__}.cxx")
del LazyCxxModule
if not PYTORCH_USE_CXX_PYTREE:
from torch.utils._pytree import (
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
PyTreeSpec,
register_pytree_node as _register_pytree_node,
tree_all,
tree_all_only,
tree_any,
tree_any_only,
tree_flatten,
tree_iter,
tree_leaves,
tree_map,
tree_map_,
tree_map_only,
tree_map_only_,
tree_structure,
tree_unflatten,
treespec_pprint,
)
else:
from torch.utils._cxx_pytree import ( # type: ignore[assignment,no-redef]
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
PyTreeSpec,
register_pytree_node as _register_pytree_node,
tree_all,
tree_all_only,
tree_any,
tree_any_only,
tree_flatten,
tree_iter,
tree_leaves,
tree_map,
tree_map_,
tree_map_only,
tree_map_only_,
tree_structure,
tree_unflatten,
treespec_pprint,
)
def register_pytree_node(
cls: type[_Any],
/,
# intentionally use `*_func` over `*_fn` to match annotations
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
) -> None:
"""Register a container-like type as pytree node.
Args:
cls (type): A Python type to treat as an internal pytree node.
flatten_func (callable): A function to be used during flattening, taking an instance of
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
passed to the ``unflatten_func``.
unflatten_func (callable): A function taking two arguments: the unflattened children, and
the auxiliary data that was returned by ``flatten_func`` and stored in the treespec.
The function should return an instance of ``cls``.
Example::
>>> # xdoctest: +SKIP
>>> from collections import UserList
... class MyList(UserList): pass
>>> # Registry a Python type with lambda functions
... register_pytree_node(
... MyList,
... lambda lst: (list(lst), None),
... lambda children, _: MyList(children),
... )
"""
_register_pytree_node(
cls,
flatten_func,
unflatten_func,
)
def __getattr__(name: str) -> _Any:
if name == "cxx":
# Lazy import
return _import_cxx_pytree_and_store()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -0,0 +1,8 @@
# Owner(s): ["module: pytree"]
from .._cxx_pytree import * # previously public APIs # noqa: F403
from .._cxx_pytree import ( # non-public internal APIs
__all__ as __all__,
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
KeyPath as KeyPath,
)

View File

@ -0,0 +1,15 @@
# Owner(s): ["module: pytree"]
from .._pytree import * # previously public APIs # noqa: F403
from .._pytree import ( # non-public internal APIs
__all__ as __all__,
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
arg_tree_leaves as arg_tree_leaves,
BUILTIN_TYPES as BUILTIN_TYPES,
GetAttrKey as GetAttrKey,
KeyEntry as KeyEntry,
KeyPath as KeyPath,
MappingKey as MappingKey,
SequenceKey as SequenceKey,
SUPPORTED_NODES as SUPPORTED_NODES,
)