Compare commits

...

72 Commits

Author SHA1 Message Date
63375b0adb Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:29:48 -07:00
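To make the caching idea above concrete, here is a minimal standalone sketch (illustrative only; the class and method names are hypothetical and not the PR's actual `_SharedState`): one shared object owns the flattened rank map, a reference to the root mesh, and a dict mapping a (layout, backend-options) key to the name of an already-created process group, so meshes derived from the same root reuse groups instead of re-creating them.

```python
# Illustrative sketch only: a shared bookkeeping object that caches one
# process-group name per (layout, backend-options) key, so meshes derived
# from the same root reuse groups instead of re-creating them.
from typing import Callable, Hashable, Optional


class SharedMeshState:
    def __init__(self, rank_map: tuple[int, ...]) -> None:
        self.rank_map = rank_map                  # flattened global ranks
        self.root_mesh: Optional[object] = None   # set once by the root mesh
        self._pg_cache: dict[tuple[Hashable, Hashable], str] = {}

    def get_or_create_group(
        self,
        layout: Hashable,
        options: Hashable,
        create: Callable[[Hashable, Hashable], str],
    ) -> str:
        key = (layout, options)
        if key not in self._pg_cache:             # create the group only once
            self._pg_cache[key] = create(layout, options)
        return self._pg_cache[key]


if __name__ == "__main__":
    created = []
    state = SharedMeshState(rank_map=tuple(range(8)))

    def make_group(layout: Hashable, options: Hashable) -> str:
        created.append(layout)
        return f"pg_{len(created)}"

    # Two meshes asking for the same (layout, options) share one group.
    a = state.get_or_create_group(((2, 4), (4, 1)), None, make_group)
    b = state.get_or_create_group(((2, 4), (4, 1)), None, make_group)
    assert a == b and len(created) == 1
```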
2cc9ff2f6a Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:29:48 -07:00
75eb58ee1d Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:27:09 -07:00
6d41a72865 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:27:09 -07:00
e4ad59038a Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-28 10:30:23 -07:00
150eac735f Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-28 10:30:23 -07:00
8d5f429f3b Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-27 16:22:44 -07:00
2ea430ce25 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-27 16:22:44 -07:00
e186bca40f Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 16:07:13 -07:00
602f1f42c2 Update on "[WIP][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 15:43:22 -07:00
24e0f27711 [WIP][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map
[ghstack-poisoned]
2025-10-21 11:43:31 -07:00
b9a6e020a2 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the spmd mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 10:22:06 -07:00
310db76d36 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the spmd mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 10:22:06 -07:00
4a70f2d033 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the spmd mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-16 15:41:23 -07:00
82fc673005 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the spmd mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-16 15:41:23 -07:00
e587947b20 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the spmd mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 15:55:15 -07:00
ff0ebf7fe5 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the spmd mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 15:55:15 -07:00
7a85a3d289 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the spmd mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 11:39:19 -07:00
baf793c54d Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the spmd mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 11:39:19 -07:00
e0c47235a3 [DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh
[ghstack-poisoned]
2025-10-14 17:13:09 -07:00
f6b4fe1c64 Update on "[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 14:11:13 -07:00
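A rough sketch of the same-root constraint described above, using hypothetical toy types rather than the PR's actual DeviceMesh API: concatenating submeshes is only well defined when they were all carved out of the same root mesh, because their entries index into one shared rank map.

```python
# Toy illustration (hypothetical names, not the PR's DeviceMesh API):
# submeshes can only be concatenated if they come from the same root mesh.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySubmesh:
    root_id: int             # identifies the root mesh it was sliced from
    dim_name: str
    ranks: tuple[int, ...]   # global ranks along this mesh dimension


def concat_submeshes(*submeshes: ToySubmesh) -> tuple[tuple[str, ...], tuple[int, ...]]:
    roots = {m.root_id for m in submeshes}
    if len(roots) != 1:
        raise ValueError("all submeshes must come from the same root mesh")
    dim_names = tuple(m.dim_name for m in submeshes)
    sizes = tuple(len(m.ranks) for m in submeshes)
    return dim_names, sizes  # dim names and shape of the combined SPMD mesh


if __name__ == "__main__":
    dp = ToySubmesh(root_id=0, dim_name="dp", ranks=(0, 4))
    tp = ToySubmesh(root_id=0, dim_name="tp", ranks=(0, 1, 2, 3))
    print(concat_submeshes(dp, tp))  # (('dp', 'tp'), (2, 4))
```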
092d3778a3 Update base for Update on "[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 14:11:13 -07:00
82b416f4db Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:56:12 -07:00
36233be9d3 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:56:12 -07:00
f56e3c8fdf Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:25:44 -07:00
ad324d0e3c Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:25:44 -07:00
a4dc3dafee Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-12 21:41:06 -07:00
766493267e Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-12 21:41:06 -07:00
bd086ac1b3 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-01 15:35:10 -07:00
411a5c7f7f Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-01 15:35:10 -07:00
3bab82c453 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-30 16:06:59 -07:00
11a624dc28 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-30 16:06:59 -07:00
4736bb57e3 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-26 22:14:44 -07:00
c2aaa5664c Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-26 22:14:44 -07:00
569a9000a5 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-25 16:17:40 -07:00
73f13abc50 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-25 16:17:40 -07:00
3038d7f285 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-24 14:40:29 -07:00
57448253f3 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-24 14:40:29 -07:00
32caf41d72 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 21:09:56 -07:00
28185f4406 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 21:09:56 -07:00
92c709c202 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 14:41:26 -07:00
f094af1e1a Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the spmd mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is a concatenation of some submeshes into a bigger mesh that is used as the spmd mesh. This PR tentatively tries to implement this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 14:41:26 -07:00
09db0ef757 [For Discussion][DeviceMesh] Implement a concatenate api for submesh
[ghstack-poisoned]
2025-09-19 11:34:56 -07:00
ecd3f525d5 Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-18 21:00:54 -07:00
b2345c972f Update base for Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-18 21:00:54 -07:00
fe44a87ed4 Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-17 22:50:00 -07:00
f0aa9cfc42 Update base for Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-17 22:50:00 -07:00
0ea286d26d Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:37:29 -07:00
621b9f2be8 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:37:29 -07:00
5ac6d410aa Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:32:59 -07:00
2bb8d19968 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:32:59 -07:00
2cd038d95d Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:59:38 -07:00
5decc0e164 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:59:38 -07:00
bafbc39603 Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:00:28 -07:00
356abd0719 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:00:28 -07:00
457bdbfaa4 Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-25 11:23:34 -07:00
5196ba3db4 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-25 11:23:34 -07:00
bf108cf3c9 [WIP][device_mesh] Implement on top of CuTe layout bookkeeping
[ghstack-poisoned]
2025-08-21 16:50:56 -07:00
389519a03c Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh so that:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users do many operations without initializing any backend.


Concretely, in this PR, we do the following:
1. Replaced all the index/offset mappings with a CuTe Layout plus a backend class which handles all the bookkeeping and creates backends if needed. Use the CuTe layout for both slicing and _flatten.
2. We also started to make DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: every device mesh transformed out of it (via slicing, flatten, unflatten, etc.) shares the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly as well.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:20:20 -07:00
873ec8442e Update base for Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh so that:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users do many operations without initializing any backend.


Concretely, in this PR, we do the following:
1. Replaced all the index/offset mappings with a CuTe Layout plus a backend class which handles all the bookkeeping and creates backends if needed. Use the CuTe layout for both slicing and _flatten.
2. We also started to make DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: every device mesh transformed out of it (via slicing, flatten, unflatten, etc.) shares the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly as well.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:20:20 -07:00
5841ede067 Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh so that:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users do many operations without initializing any backend.


Concretely, in this PR, we do the following:
1. Replaced all the index/offset mappings with a CuTe Layout plus a backend class which handles all the bookkeeping and creates backends if needed. Use the CuTe layout for both slicing and _flatten.
2. We also started to make DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: every device mesh transformed out of it (via slicing, flatten, unflatten, etc.) shares the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly as well.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:13:33 -07:00
4c90367bfb Update base for Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh so that:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users do many operations without initializing any backend.


Concretely, in this PR, we do the following:
1. Replaced all the index/offset mappings with a CuTe Layout plus a backend class which handles all the bookkeeping and creates backends if needed. Use the CuTe layout for both slicing and _flatten.
2. We also started to make DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: every device mesh transformed out of it (via slicing, flatten, unflatten, etc.) shares the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly as well.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:13:32 -07:00
27910fb22a Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to implement the 


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 11:46:45 -07:00
e5cbce1780 Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 11:03:28 -07:00
4dd9c8cf2d Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 08:55:35 -07:00
8c0643a671 Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 22:34:22 -07:00
0f7736dd55 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e. the device layout. The entries are all integers, but because meshes can have various shapes, the internal bookkeeping becomes much more challenging. Currently our internal bookkeeping inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten` we would need to introduce very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging the CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things go incrementally, we proposed a couple of steps here: https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243, and this PR is step 0.

We only bring in the layout class as a private class, with detailed explanations and comments (thanks to an LLM) and unit tests to showcase that the code indeed works as expected. The only big change we make here is that, instead of assuming the smallest-stride dimension comes first (left to right) inside coalesce and complement, we follow PyTorch's ordering and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 16:02:43 -07:00
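As a rough illustration of the layout idea referenced above (not the PR's actual layout class): a CuTe-style layout is a (shape, stride) pair that maps an N-D coordinate to a linear index via a dot product, and with the PyTorch ordering adopted here the rightmost dimension is the innermost one (stride 1).

```python
# Toy CuTe-style layout: a (shape, stride) pair maps an N-D coordinate to a
# linear index via sum(coord[i] * stride[i]); the rightmost dim is innermost
# (stride 1), matching the PyTorch ordering mentioned in the commit message.
def linearize(coord: tuple[int, ...], stride: tuple[int, ...]) -> int:
    return sum(c * s for c, s in zip(coord, stride))


if __name__ == "__main__":
    shape, stride = (2, 4), (4, 1)   # a 2 x 4 mesh, row-major
    mesh = [
        [linearize((i, j), stride) for j in range(shape[1])]
        for i in range(shape[0])
    ]
    print(mesh)                      # [[0, 1, 2, 3], [4, 5, 6, 7]]
    # This layout coalesces to shape (8,), stride (1,) because
    # shape[1] * stride[1] == 4 == stride[0].
    assert [linearize((k,), (1,)) for k in range(8)] == [r for row in mesh for r in row]
```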
9e8f81aaa8 [DeviceMesh] Simplifying internal bookkeeping with CuTe layout
[ghstack-poisoned]
2025-08-20 16:02:43 -07:00
72706c7cb9 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e. the device layout. The entries are all integers, but because meshes can have various shapes, the internal bookkeeping becomes much more challenging. Currently our internal bookkeeping inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten` we would need to introduce very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging the CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things go incrementally, we proposed a couple of steps here: https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243, and this PR is step 0.

We only bring in the layout class as a private class, with detailed explanations and comments (thanks to an LLM) and unit tests to showcase that the code indeed works as expected. The only big change we make here is that, instead of assuming the smallest-stride dimension comes first (left to right) inside coalesce and complement, we follow PyTorch's ordering and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 08:31:24 -07:00
01a3a8550a Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e. the device layout. The entries are all integers, but because meshes can have various shapes, the internal bookkeeping becomes much more challenging. Currently our internal bookkeeping inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten` we would need to introduce very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging the CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things go incrementally, we proposed a couple of steps here: https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243, and this PR is step 0.

We only bring in the layout class as a private class, with detailed explanations and comments (thanks to an LLM) and unit tests to showcase that the code indeed works as expected. The only big change we make here is that, instead of assuming the smallest-stride dimension comes first (left to right) inside coalesce and complement, we follow PyTorch's ordering and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-19 22:42:05 -07:00
00e03cccf3 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e. the device layout. The entries are all integers, but because meshes can have various shapes, the internal bookkeeping becomes much more challenging. Currently our internal bookkeeping inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten` we would need to introduce very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging the CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things go incrementally, we proposed a couple of steps here: https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243, and this PR is step 0.

We only bring in the layout class as a private class, with detailed explanations and comments (thanks to an LLM) and unit tests to showcase that the code indeed works as expected. The only big change we make here is that, instead of assuming the smallest-stride dimension comes first (left to right) inside coalesce and complement, we follow PyTorch's ordering and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-19 21:27:37 -07:00
1a58e8dad5 [DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping
[ghstack-poisoned]
2025-08-19 15:57:40 -07:00
7 changed files with 364 additions and 154 deletions

View File

@ -1000,6 +1000,9 @@ class TestDeviceMeshGetItem(DTensorTestBase):
)
non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
# Test pg caching when unflattening into the same layout.
self.assertEqual(non_ep_mesh["dp"].get_group(), ep_mesh["dp"].get_group())
self.assertEqual(non_ep_mesh["tp"].get_group(), ep_mesh["ep_tp"].get_group())
self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp"))

View File

@ -1,5 +1,6 @@
#pragma once
#include <functional>
#include <memory>
#include <utility>
#include <vector>
@ -48,6 +49,12 @@ class TORCH_API Backend : public torch::CustomClassHolder {
const std::string backend;
std::string group_name;
std::vector<uint64_t> global_ranks_in_group;
bool operator==(const Options& other) const noexcept {
return timeout == other.timeout && backend == other.backend &&
group_name == other.group_name &&
global_ranks_in_group == other.global_ranks_in_group;
}
};
explicit Backend(int rank, int size);
@ -511,3 +518,24 @@ class TORCH_API Backend : public torch::CustomClassHolder {
};
} // namespace c10d
// Small helper to fold a value into a running hash seed (boost-style hash_combine).
inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}
namespace std {
template <>
struct hash<c10d::Backend::Options> {
std::size_t operator()(const c10d::Backend::Options& o) const noexcept {
std::size_t h = 0;
hash_combine(h, std::hash<long long>{}(o.timeout.count()));
hash_combine(h, std::hash<std::string>{}(o.backend));
hash_combine(h, std::hash<std::string>{}(o.group_name));
for (auto x : o.global_ranks_in_group)
hash_combine(h, std::hash<uint64_t>{}(x));
return h;
}
};
} // namespace std
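The equality operator and `std::hash` specialization added above exist so that, once exposed to Python, an Options object can behave as a value-like dictionary key in the (layout, pg_option) backend cache. A stand-in sketch of that requirement, using a fake options class rather than the real c10d bindings:

```python
# Stand-in illustration (FakeOptions is not a real c10d class): any object
# used inside a dict key needs consistent __eq__ and __hash__, which is what
# the C++ operator== / std::hash above (and the pybind __eq__/__hash__ later
# in this PR) provide for the c10d Options types.
from typing import Optional


class FakeOptions:
    def __init__(self, backend: str, timeout_s: int) -> None:
        self.backend = backend
        self.timeout_s = timeout_s

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, FakeOptions)
            and (self.backend, self.timeout_s) == (other.backend, other.timeout_s)
        )

    def __hash__(self) -> int:
        return hash((self.backend, self.timeout_s))


cache: dict[tuple[str, Optional[FakeOptions]], str] = {}
opts_a = FakeOptions("nccl", 600)
opts_b = FakeOptions("nccl", 600)             # distinct object, equal value
cache[("layout_0", opts_a)] = "pg_0"
assert cache[("layout_0", opts_b)] == "pg_0"  # equal options hit the cache
```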

View File

@ -260,6 +260,23 @@ class TORCH_API ProcessGroupGloo : public Backend {
std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
int threads;
bool operator==(const Options& other) const noexcept {
// 1) compare base first
if (!static_cast<const Backend::Options&>(*this).operator==(other))
return false;
// 2) compare devices by identity
if (devices.size() != other.devices.size())
return false;
for (size_t i = 0; i < devices.size(); ++i) {
if (devices[i].get() != other.devices[i].get()) // pointer identity
return false;
}
// 3) compare added scalar fields
return threads == other.threads;
}
};
const std::string getBackendName() const override {
@ -494,4 +511,24 @@ class TORCH_API ProcessGroupGloo : public Backend {
} // namespace c10d
namespace std {
template <>
struct hash<c10d::ProcessGroupGloo::Options> {
std::size_t operator()(
const c10d::ProcessGroupGloo::Options& o) const noexcept {
std::size_t h = 0;
// reuse base hash
hash_combine(
h,
std::hash<c10d::Backend::Options>{}(
static_cast<const c10d::Backend::Options&>(o)));
// add derived fields
for (auto const& dev : o.devices)
hash_combine(h, std::hash<const void*>{}(dev.get()));
hash_combine(h, std::hash<int>{}(o.threads));
return h;
}
};
} // namespace std
#endif // USE_C10D_GLOO

View File

@ -550,6 +550,33 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead.
int split_color{-2};
#endif
bool operator==(const Options& other) const noexcept {
// 1) compare base first
if (!static_cast<const Backend::Options&>(*this).operator==(other))
return false;
// 2) simple fields
if (is_high_priority_stream != other.is_high_priority_stream) {
return false;
}
if (split_color != other.split_color) {
return false;
}
// 3) split_from: compare by identity
if (split_from.get() != other.split_from.get()) {
return false;
}
#ifdef NCCL_HAS_CONFIG
// 4) config
if (std::memcmp(&config, &other.config, sizeof(ncclConfig_t)) != 0) {
return false;
}
#endif
return true;
}
};
// Helper class related to TORCH_NCCL_DESYNC_DEBUG
@ -1504,4 +1531,46 @@ typedef bool (*gil_checker_t)();
TORCH_API gil_checker_t& get_gil_checker();
} // namespace c10d
#ifdef NCCL_HAS_CONFIG
inline std::size_t hash_nccl_config(const ncclConfig_t& cfg) noexcept {
const unsigned char* p = reinterpret_cast<const unsigned char*>(&cfg);
std::size_t h = 0;
for (std::size_t i = 0; i < sizeof(cfg); ++i) {
hash_combine(h, static_cast<std::size_t>(p[i]));
}
return h;
}
#endif
namespace std {
template <>
struct hash<c10d::ProcessGroupNCCL::Options> {
std::size_t operator()(
const c10d::ProcessGroupNCCL::Options& o) const noexcept {
std::size_t h = 0;
// 1) base
hash_combine(
h,
std::hash<c10d::Backend::Options>{}(
static_cast<const c10d::Backend::Options&>(o)));
// 2) trivial extras
hash_combine(h, std::hash<bool>{}(o.is_high_priority_stream));
hash_combine(h, std::hash<int>{}(o.split_color));
// 3) pointer identity for split_from
hash_combine(h, std::hash<const void*>{}(o.split_from.get()));
#ifdef NCCL_HAS_CONFIG
// 4) config — option A: hash bytes
hash_combine(h, hash_nccl_config(o.config));
#endif
return h;
}
};
} // namespace std
#endif // USE_C10D_NCCL

View File

@ -3107,7 +3107,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
.def_readwrite(
"global_ranks_in_group",
&::c10d::Backend::Options::global_ranks_in_group)
.def_readwrite("group_name", &::c10d::Backend::Options::group_name);
.def_readwrite("group_name", &::c10d::Backend::Options::group_name)
.def(
"__eq__",
[](const ::c10d::Backend::Options& a,
const ::c10d::Backend::Options& b) { return a == b; })
.def("__hash__", [](const ::c10d::Backend::Options& a) {
return std::hash<::c10d::Backend::Options>{}(a);
});
#ifdef USE_C10D_GLOO
auto processGroupGloo =
@ -3121,7 +3128,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
processGroupGloo, "_Options", backendOptions)
.def(py::init<>())
.def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads)
.def(
"__eq__",
[](const ::c10d::ProcessGroupGloo::Options& a,
const ::c10d::ProcessGroupGloo::Options& b) { return a == b; })
.def("__hash__", [](const ::c10d::ProcessGroupGloo::Options& a) {
return std::hash<::c10d::ProcessGroupGloo::Options>{}(a);
});
processGroupGloo
.def_static(
@ -3481,6 +3495,15 @@ Example::
"split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
.def_readwrite(
"split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
.def(
"__eq__",
[](const ::c10d::ProcessGroupNCCL::Options& a,
const ::c10d::ProcessGroupNCCL::Options& b) { return a == b; })
.def(
"__hash__",
[](const ::c10d::ProcessGroupNCCL::Options& a) {
return std::hash<::c10d::ProcessGroupNCCL::Options>{}(a);
})
.def(
"__copy__",
[](const ::c10d::ProcessGroupNCCL::Options& self) {

View File

@ -865,7 +865,9 @@ class _LocalDeviceMesh:
coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
for r in lm.ranks:
rank_tensor = self._layout.remap_to_tensor(self._rank_map)
rank_tensor = self._layout.remap_to_tensor(
self._shared_state.get_rank_map()
)
rank_coords = (rank_tensor == r).nonzero().tolist()
assert len(rank_coords) == 1
for d, c in enumerate(rank_coords[0][1:]):

View File

@ -125,6 +125,172 @@ else:
"""
return getattr(torch, device_type, None)
class _SharedState:
"""
This class is used to store the shared state of the DeviceMesh.
"""
_rank_map: torch.Tensor
_root_mesh: Optional["DeviceMesh"]
_backend_cache: dict[tuple[_MeshLayout, Optional[C10dBackend.Options]], str]
def __init__(
self, rank_map: torch.Tensor, root_mesh: Optional["DeviceMesh"] = None
) -> None:
self._rank_map = rank_map
self._root_mesh = root_mesh
self._backend_cache: dict[
tuple[_MeshLayout, Optional[C10dBackend.Options]], str
] = {}
def get_rank_map(self) -> torch.Tensor:
return self._rank_map
def get_root_mesh(self) -> Optional["DeviceMesh"]:
return self._root_mesh
def update_backend_cache(
self,
layout: _MeshLayout,
backend: str,
pg_option: Optional[C10dBackend.Options],
) -> None:
if (layout, pg_option) not in self._backend_cache:
self._backend_cache[(layout, pg_option)] = backend
def get_backend_from_cache(
self, layout: _MeshLayout, pg_option: Optional[C10dBackend.Options]
) -> Optional[str]:
return self._backend_cache.get((layout, pg_option), None)
def _init_one_process_group(
self,
sub_layout: _MeshLayout,
dim_name: str,
backend_override: BackendConfig,
) -> Optional[str]:
# Generate a 2D global mesh tensor for the current dim for PG creation.
pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(self._rank_map)
backend, pg_options = backend_override
# We need to explicitly pass in the timeout when it is specified in the options; otherwise
# the default timeout will override the timeout set in the options.
# TODO: remove this once it is fixed at the c10d level.
timeout = pg_options._timeout if pg_options else None
# If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group descriptions
# of the subgroups would be `mesh_dp` and `mesh_tp`.
# If the mesh doesn't have mesh_dim_names, then the group descriptions of the
# subgroups would be `mesh_dim_0` and `mesh_dim_1`.
group_desc = f"mesh_{dim_name}"
dim_group = None
default_group = _get_default_group()
# Early return when this sub_layout spans the whole world and no backend override is given.
if sub_layout.numel() == get_world_size() and backend_override == (
None,
None,
):
# Reuse the default pg for this dim only if the default pg is compatible with the mesh's device type.
# Otherwise, create a new pg.
ranks = list(range(get_world_size()))
dim_group = (
new_group(
backend="cpu:gloo,cuda:nccl",
ranks=ranks,
group_desc="mesh_default",
)
if torch.cuda.is_available()
and get_backend(default_group) == "gloo"
else default_group
)
return dim_group.group_name # type: ignore[union-attr]
# If bound_device_id exists, it means the nccl communicator has been eagerly initialized
# so that we can use `split_group` to create subgroups through `ncclCommSplit`.
# In this case, we only need to make one API call (`split_group`) for the subgroup creation
# per mesh dimension. In a 2 * 4 mesh, we only need to make two API calls per rank to create
# all the subgroups.
# Otherwise, we need to make more than one API call (`new_group`) for subgroup creation. The
# number of API calls is equal to the number of subgroups for each mesh dimension. In a 2 * 4
# mesh, we need to make 2 + 4 = 6 API calls per rank to create all the subgroups.
if (
getattr(default_group, "bound_device_id", None) is not None
and torch.cuda.is_available()
and (
backend is None
or default_group._get_backend(torch.device("cuda")).name()
== backend
)
):
dim_group = split_group(
parent_pg=default_group,
timeout=timeout,
pg_options=pg_options,
split_ranks=pg_ranks_by_dim.tolist(),
group_desc=group_desc,
)
return dim_group.group_name # type: ignore[union-attr]
# If `split_group` was not used above, we create the subgroups with `new_group` by looping over
# `pg_ranks_by_dim` and recording the `group_name` when the current rank is in the subgroup.
pg_name = None
for dim_mesh in pg_ranks_by_dim:
subgroup_ranks = dim_mesh.tolist()
dim_group = new_group(
ranks=subgroup_ranks,
timeout=timeout,
backend=backend,
pg_options=pg_options,
group_desc=group_desc,
)
# Only record the group name if the current rank is in the subgroup.
if get_rank() in subgroup_ranks:
if pg_name is not None:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {get_rank()} "
f"in {subgroup_ranks}!"
)
pg_name = dim_group.group_name
return pg_name
def _init_process_groups(
self,
layout: _MeshLayout,
mesh_dim_names: Optional[tuple[str, ...]],
backend_override: tuple[BackendConfig, ...],
) -> list[str]:
# group_name associated with each mesh dimension, each
# mesh dimension should have one sub-group per rank
dim_group_names: list[str] = []
# Create sub pgs based on the mesh argument specified.
for dim in range(len(layout)):
dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}"
backend_cache = self.get_backend_from_cache(
layout[dim], backend_override[dim][1]
)
if backend_cache is not None:
dim_group_names.append(backend_cache)
else:
dim_group_names.append(
self._init_one_process_group( # type: ignore[arg-type]
layout[dim], dim_name, backend_override[dim]
)
)
if dim_group_names[-1] is not None:
self.update_backend_cache(
layout[dim], dim_group_names[-1], backend_override[dim][1]
)
if any(n is None for n in dim_group_names):
assert all(n is None for n in dim_group_names)
return []
return dim_group_names
torch.serialization.add_safe_globals([_SharedState])
class DeviceMesh:
"""
DeviceMesh represents a mesh of devices, where layout of devices could be
@ -174,12 +340,11 @@ else:
"""
_device_type: str
_rank_map: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout
_root_mesh: Optional["DeviceMesh"] = None
# Record flatten mesh name to its flattened mesh in root mesh.
_flatten_mapping: dict[str, "DeviceMesh"]
_shared_state: _SharedState
def __init__(
self,
@ -191,13 +356,12 @@ else:
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_rank_map: Optional[torch.Tensor] = None,
_root_mesh: Optional["DeviceMesh"] = None,
_shared_state: Optional[_SharedState] = None,
) -> None:
if mesh is not None:
if _layout is not None or _rank_map is not None:
if _layout is not None or _shared_state is not None:
raise TypeError(
"Cannot provide _layout and/or _rank_map if passing explicit mesh"
"Cannot provide _layout and/or _shared_state if passing explicit mesh"
)
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
@ -207,28 +371,31 @@ else:
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
_layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
_rank_map = mesh_tensor.flatten()
rank_map = mesh_tensor.flatten()
self._shared_state = _SharedState(rank_map, self)
else:
if _layout is None or _rank_map is None:
if _layout is None or _shared_state is None:
raise TypeError(
"The mesh argument is required except for PRIVATE USAGE ONLY!"
)
rank_map = _shared_state.get_rank_map()
self._shared_state = _shared_state
if self._shared_state.get_root_mesh() is None:
self._shared_state._root_mesh = self
assert _layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
assert _rank_map.ndim == 1, "The rank map must be 1-dimensional"
assert _rank_map.is_contiguous(), "The rank map must be contiguous"
assert _rank_map.numel() >= _layout.cosize(), (
f"The rank map contains {_rank_map.numel()} element, "
assert rank_map.ndim == 1, "The rank map must be 1-dimensional"
assert rank_map.is_contiguous(), "The rank map must be contiguous"
assert rank_map.numel() >= _layout.cosize(), (
f"The rank map contains {rank_map.numel()} element, "
f"which isn't large enough for layout {_layout}"
)
self._device_type = device_type
self._layout = _layout
self._rank_map = _rank_map
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
self._root_mesh = _root_mesh
if backend_override is None:
backend_override = ((None, None),) * len(self._layout)
@ -247,9 +414,8 @@ else:
# process (we need to know if the current global rank is in the mesh or not).
if _init_backend:
self._setup_world_group_and_device()
self._dim_group_names = self._init_process_groups(
self._dim_group_names = self._shared_state._init_process_groups(
self._layout,
self._rank_map,
self._mesh_dim_names,
backend_override,
)
@ -269,7 +435,7 @@ else:
)
# private field to pre-generate DeviceMesh's hash
self._flatten_rank_map = tuple(self._rank_map.tolist())
self._flatten_rank_map = tuple(rank_map.tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@ -281,7 +447,7 @@ else:
@property
def mesh(self) -> torch.Tensor:
"""Returns the tensor representing the layout of devices."""
full_mesh = self._layout.remap_to_tensor(self._rank_map)
full_mesh = self._layout.remap_to_tensor(self._shared_state.get_rank_map())
if full_mesh.size(0) == 1:
return full_mesh[0]
my_coords = (full_mesh == get_rank()).nonzero()
@ -349,125 +515,9 @@ else:
return _get_default_group()
@staticmethod
def _init_process_groups(
layout: _MeshLayout,
rank_map: torch.Tensor,
mesh_dim_names: Optional[tuple[str, ...]],
backend_override: tuple[BackendConfig, ...],
) -> list[str]:
# group_name associated with each mesh dimension, each
# mesh dimension should have one sub-group per rank
#
dim_group_names: list[str] = []
default_group = _get_default_group()
if (
len(layout) == 1
and layout.numel() == get_world_size()
and backend_override[0] == (None, None)
):
# Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
# Otherwise, create new pg.
ranks = list(range(get_world_size()))
dim_group = (
new_group(
backend="cpu:gloo,cuda:nccl",
ranks=ranks,
group_desc="mesh_default",
)
if torch.cuda.is_available()
and get_backend(default_group) == "gloo"
else default_group
)
dim_group_names.append(dim_group.group_name)
else:
# create sub pgs base on the mesh argument specified
for dim in range(len(layout)):
# swap the current dim to the last dim
# then reshape to flatten out other dims
pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
backend, pg_options = backend_override[dim]
# We need to explicitly pass in timeout when specified in option, otherwise
# the default timeout will be used to override the timeout set in option.
# TODO: remove this once we have fixed inside c10d level.
timeout = pg_options._timeout if pg_options else None
# If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
# of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
# If the mesh doesn't not have a mesh_dim_names, then the group description of the
# subgroup would be `mesh_dim_0` and `mesh_dim_1`.
group_desc = (
f"mesh_{mesh_dim_names[dim]}"
if mesh_dim_names
else f"mesh_dim_{dim}"
)
# If bound_device_id exists, it means the nccl communicator has been eagerly initialized
# so that we can use `split_group` to create subgroups through `ncclCommSplit`.
# In this case, we only need to make one API call (`split_group``) for the subgroup creation
# for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per ranks to create
# all the subgroups.
# Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The
# numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
# mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups.
dim_group = None
has_split_group = False
if (
(
bound_device_id := getattr(
default_group, "bound_device_id", None
)
)
is not None
and torch.cuda.is_available()
and (
backend is None
or default_group._get_backend(torch.device("cuda")).name()
== backend
)
):
dim_group = split_group(
parent_pg=default_group,
timeout=timeout,
pg_options=pg_options,
split_ranks=pg_ranks_by_dim.tolist(),
group_desc=group_desc,
)
has_split_group = True
# If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
# and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
# Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
# along with appending information to the `dim_group_names` list whenever necessary.
for dim_mesh in pg_ranks_by_dim:
subgroup_ranks = dim_mesh.tolist()
# We temporarily revert the reuse subgroup, since it breaks two internal tests.
# Temporarily reverting to resolve test timeout while root-causing.
# TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
# pyrefly: ignore [unbound-name]
if bound_device_id is None or not has_split_group:
dim_group = new_group(
ranks=subgroup_ranks,
timeout=timeout,
backend=backend,
pg_options=pg_options,
group_desc=group_desc,
)
# only add to dim_groups if the current rank in the subgroup
if get_rank() in subgroup_ranks:
if len(dim_group_names) > dim:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {get_rank()} "
f"in {subgroup_ranks}!"
)
dim_group_names.append(dim_group.group_name) # type: ignore[union-attr]
return dim_group_names
def _get_root_mesh(self) -> "DeviceMesh":
return self._root_mesh if self._root_mesh else self
root_mesh = self._shared_state.get_root_mesh()
return root_mesh if root_mesh is not None else self
def __enter__(self) -> "DeviceMesh":
# set this mesh as the current mesh in mesh env
@ -668,10 +718,9 @@ else:
res_submesh = DeviceMesh(
self._device_type,
_layout=layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=submesh_dim_names,
_root_mesh=root_mesh,
_init_backend=False,
_shared_state=root_mesh._shared_state,
)
res_submesh._dim_group_names = slice_dim_group_name
return res_submesh
@ -718,10 +767,9 @@ else:
res_flattened_mesh = DeviceMesh(
root_mesh._device_type,
_layout=flattened_mesh_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=(mesh_dim_name,),
_root_mesh=root_mesh,
backend_override=(backend_override,),
_shared_state=root_mesh._shared_state,
)
root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh
@ -852,7 +900,7 @@ else:
"""
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
layout = self._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
pg_ranks_by_dim = layout.remap_to_tensor(self._shared_state.get_rank_map())
cur_rank = self.get_rank()
res_submeshes = []
for mesh_1d in pg_ranks_by_dim:
@ -1095,10 +1143,9 @@ else:
res_mesh = DeviceMesh(
self.device_type,
_layout=unflattened_layout,
_rank_map=root_mesh._rank_map,
mesh_dim_names=tuple(unflattened_mesh_dim_names),
_root_mesh=root_mesh,
_init_backend=False,
_shared_state=root_mesh._shared_state,
)
# If original mesh has initiated its backend, we need to initialize the backend
@ -1107,11 +1154,12 @@ else:
# per dim backend init.
if hasattr(self, "_dim_group_names"):
dim_group_names = self._dim_group_names.copy()
dim_group_names[dim : dim + 1] = self._init_process_groups(
partial_layout,
root_mesh._rank_map,
mesh_dim_names,
backend_override,
dim_group_names[dim : dim + 1] = (
root_mesh._shared_state._init_process_groups(
partial_layout,
mesh_dim_names,
backend_override,
)
)
res_mesh._dim_group_names = dim_group_names
@ -1208,10 +1256,9 @@ else:
res_mesh = DeviceMesh(
device_mesh_list[0].device_type,
_layout=concat_mesh_layout,
_rank_map=device_mesh_list[0]._rank_map,
mesh_dim_names=tuple(concat_dim_names),
_root_mesh=device_mesh_list[0]._get_root_mesh(),
_init_backend=False,
_shared_state=device_mesh_list[0]._shared_state,
)
res_mesh._dim_group_names = concat_dim_group_name
return res_mesh
@ -1340,12 +1387,13 @@ else:
# external device type has been set to be (e.g. meta)
with torch.device("cpu"):
rank_map = torch.arange(layout.numel(), dtype=torch.int)
shared_state = _SharedState(rank_map)
device_mesh = DeviceMesh(
device_type=device_type,
_layout=layout,
_rank_map=rank_map,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override_tuple,
_shared_state=shared_state,
)
return device_mesh