Compare commits

...

78 Commits

Author SHA1 Message Date
54abae7c07 Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-14 10:23:04 -08:00
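
For illustration, here is a minimal self-contained sketch of the caching idea described above — not the actual implementation in this stack; SharedState and get_or_create_pg are simplified stand-ins for the _SharedState and _backend_cache that appear in the diff further down:

from typing import Any, Callable, Optional

class SharedState:
    """Per-universe bookkeeping: rank map, root mesh, and a PG cache."""

    def __init__(self, rank_map: list[int], root_mesh: Optional[Any] = None) -> None:
        self._rank_map = rank_map              # flat list of global ranks
        self._root_mesh = root_mesh            # back-pointer to the root mesh
        self._pg_cache: dict[tuple, str] = {}  # (layout, options) -> PG name

    def get_or_create_pg(self, layout: tuple, options: Optional[str],
                         create: Callable[[], str]) -> str:
        key = (layout, options)
        if key not in self._pg_cache:          # reuse an existing PG when the
            self._pg_cache[key] = create()     # same layout/options reappear
        return self._pg_cache[key]

state = SharedState(rank_map=list(range(8)))
pg_a = state.get_or_create_pg(((2, 4), (4, 1)), None, lambda: "pg_0")
pg_b = state.get_or_create_pg(((2, 4), (4, 1)), None, lambda: "pg_99")
assert pg_a == pg_b == "pg_0"  # the second request hits the cache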
6a5dc667a4 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-14 10:23:04 -08:00
6aca75eab9 Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-14 09:20:33 -08:00
f8d0bef572 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-14 09:20:33 -08:00
23099ad498 Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-14 09:07:17 -08:00
1842dde349 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-14 09:07:17 -08:00
63375b0adb Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:29:48 -07:00
2cc9ff2f6a Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:29:48 -07:00
75eb58ee1d Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:27:09 -07:00
6d41a72865 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:27:09 -07:00
e4ad59038a Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-28 10:30:23 -07:00
150eac735f Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-28 10:30:23 -07:00
8d5f429f3b Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-27 16:22:44 -07:00
2ea430ce25 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-27 16:22:44 -07:00
e186bca40f Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 16:07:13 -07:00
602f1f42c2 Update on "[WIP][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 15:43:22 -07:00
24e0f27711 [WIP][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map
[ghstack-poisoned]
2025-10-21 11:43:31 -07:00
b9a6e020a2 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 10:22:06 -07:00
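
As a sketch of the flow this enables — the concatenate API is private, so its exact signature is an assumption here — assuming a launch such as torchrun --nproc-per-node 8:

from torch.distributed.device_mesh import init_device_mesh

# 2 x 4 mesh: "dp" for FSDP across hosts, "tp" for tensor parallel within.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

# Previously: slice the combined SPMD mesh back out of the root mesh.
spmd_mesh = mesh_2d["dp", "tp"]

# With the concatenate API (hypothetical signature): build the SPMD mesh
# directly from the submeshes instead of going through the root, e.g.
#   spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh])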
310db76d36 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 10:22:06 -07:00
4a70f2d033 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-16 15:41:23 -07:00
82fc673005 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-16 15:41:23 -07:00
e587947b20 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 15:55:15 -07:00
ff0ebf7fe5 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 15:55:15 -07:00
7a85a3d289 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 11:39:19 -07:00
baf793c54d Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 11:39:19 -07:00
e0c47235a3 [DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh
[ghstack-poisoned]
2025-10-14 17:13:09 -07:00
f6b4fe1c64 Update on "[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 14:11:13 -07:00
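
A small worked example of why the same-root constraint matters: mesh entries index into a single root's rank map, so the same coordinate can name different global ranks under different roots. The two meshes below are hypothetical, chosen only for illustration:

# Two 2 x 4 root meshes over the same 8 ranks, ordered differently;
# coordinate (0, 1) names a different global rank in each.
root_a = [[0, 1, 2, 3],
          [4, 5, 6, 7]]
root_b = [[0, 2, 4, 6],
          [1, 3, 5, 7]]
assert root_a[0][1] == 1
assert root_b[0][1] == 2
# Concatenating a submesh taken from root_a with one taken from root_b
# would therefore mix two incompatible coordinate systems for device
# allocation, which is why all inputs must derive from the same root.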
092d3778a3 Update base for Update on "[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 14:11:13 -07:00
82b416f4db Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:56:12 -07:00
36233be9d3 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:56:12 -07:00
f56e3c8fdf Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:25:44 -07:00
ad324d0e3c Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:25:44 -07:00
a4dc3dafee Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-12 21:41:06 -07:00
766493267e Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-12 21:41:06 -07:00
bd086ac1b3 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-01 15:35:10 -07:00
411a5c7f7f Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-01 15:35:10 -07:00
3bab82c453 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-30 16:06:59 -07:00
11a624dc28 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-30 16:06:59 -07:00
4736bb57e3 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-26 22:14:44 -07:00
c2aaa5664c Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-26 22:14:44 -07:00
569a9000a5 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-25 16:17:40 -07:00
73f13abc50 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-25 16:17:40 -07:00
3038d7f285 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-24 14:40:29 -07:00
57448253f3 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-24 14:40:29 -07:00
32caf41d72 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 21:09:56 -07:00
28185f4406 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 21:09:56 -07:00
92c709c202 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 14:41:26 -07:00
f094af1e1a Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements this API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 14:41:26 -07:00
09db0ef757 [For Discussion][DeviceMesh] Implement a concatenate api for submesh
[ghstack-poisoned]
2025-09-19 11:34:56 -07:00
ecd3f525d5 Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-18 21:00:54 -07:00
b2345c972f Update base for Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-18 21:00:54 -07:00
fe44a87ed4 Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-17 22:50:00 -07:00
f0aa9cfc42 Update base for Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-17 22:50:00 -07:00
0ea286d26d Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:37:29 -07:00
621b9f2be8 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:37:29 -07:00
5ac6d410aa Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:32:59 -07:00
2bb8d19968 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:32:59 -07:00
2cd038d95d Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:59:38 -07:00
5decc0e164 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:59:38 -07:00
bafbc39603 Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:00:28 -07:00
356abd0719 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:00:28 -07:00
457bdbfaa4 Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-25 11:23:34 -07:00
5196ba3db4 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-25 11:23:34 -07:00
bf108cf3c9 [WIP][device_mesh] Implement on top of CuTe layout bookkeeping
[ghstack-poisoned]
2025-08-21 16:50:56 -07:00
389519a03c Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh to:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users do many operations without initializing any backend.

Concretely, in this PR, we do the following:
1. Replaced all index/offset mappings with a CuTe Layout and a backend class that handles all the bookkeeping and creates a backend if needed. We use the CuTe layout for both slicing and _flatten.
2. We also started to make DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: every device mesh transformed out of it (slicing, flatten, unflatten, etc.) shares the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly as well.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:20:20 -07:00
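
A hedged sketch of the "universe" semantics described above, assuming a launch such as torchrun --nproc-per-node 8; this illustrates the commit message's claim rather than verified library behavior:

from torch.distributed.device_mesh import init_device_mesh

root = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_a = root["tp"]
tp_b = root["tp"]
# Submeshes derived from the same root live in one universe and share
# the backend: repeated slicing resolves to the same process group.
assert tp_a.get_group() == tp_b.get_group()

# A second init_device_mesh call starts a separate universe; per the
# description above, its process groups are created independently even
# though the layout matches.
other = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))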
873ec8442e Update base for Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh to:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users do many operations without initializing any backend.

Concretely, in this PR, we do the following:
1. Replaced all index/offset mappings with a CuTe Layout and a backend class that handles all the bookkeeping and creates a backend if needed. We use the CuTe layout for both slicing and _flatten.
2. We also started to make DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: every device mesh transformed out of it (slicing, flatten, unflatten, etc.) shares the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly as well.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:20:20 -07:00
5841ede067 Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh to:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users do many operations without initializing any backend.

Concretely, in this PR, we do the following:
1. Replaced all index/offset mappings with a CuTe Layout and a backend class that handles all the bookkeeping and creates a backend if needed. We use the CuTe layout for both slicing and _flatten.
2. We also started to make DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: every device mesh transformed out of it (slicing, flatten, unflatten, etc.) shares the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly as well.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:13:33 -07:00
4c90367bfb Update base for Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh to:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users do many operations without initializing any backend.

Concretely, in this PR, we do the following:
1. Replaced all index/offset mappings with a CuTe Layout and a backend class that handles all the bookkeeping and creates a backend if needed. We use the CuTe layout for both slicing and _flatten.
2. We also started to make DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: every device mesh transformed out of it (slicing, flatten, unflatten, etc.) shares the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly as well.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:13:32 -07:00
27910fb22a Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to implement the 


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 11:46:45 -07:00
e5cbce1780 Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 11:03:28 -07:00
4dd9c8cf2d Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 08:55:35 -07:00
8c0643a671 Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 22:34:22 -07:00
0f7736dd55 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e., the device layout. The entries are all integers, but because meshes can have various shapes, internal bookkeeping becomes much more challenging. Currently our internal bookkeeping inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten`, we would need to introduce very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging the CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things go incrementally, we proposed a couple of steps here (https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243), and this PR is step 0.

We only bring in the layout class as a private class, add detailed explanations and comments (thanks to an LLM), and add a unit test to showcase that the code indeed works as expected. The only big change we make here is that, instead of assuming the smallest-stride dimension goes from left to right inside coalesce and complement, we follow PyTorch's ordering and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 16:02:43 -07:00
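
To make the stride convention concrete, here is a minimal self-contained sketch (not the actual private layout class) of row-major strides with the innermost dimension rightmost, as described above:

# Map an n-d coordinate to a linear rank index via strides, with the
# innermost (fastest-varying) dim on the right, matching PyTorch's
# row-major convention rather than CuTe's default ordering.
def row_major_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    strides, running = [], 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))

def linearize(coord: tuple[int, ...], strides: tuple[int, ...]) -> int:
    return sum(c * s for c, s in zip(coord, strides))

shape = (2, 4)                      # 2 hosts x 4 GPUs per host
strides = row_major_strides(shape)  # (4, 1)
assert linearize((1, 2), strides) == 6  # host 1, local GPU 2 -> rank 6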
9e8f81aaa8 [DeviceMesh] Simplifying internal bookkeeping with CuTe layout
[ghstack-poisoned]
2025-08-20 16:02:43 -07:00
72706c7cb9 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e., the device layout. The entries are all integers, but because meshes can have various shapes, internal bookkeeping becomes much more challenging. Currently our internal bookkeeping inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten`, we would need to introduce very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging the CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things go incrementally, we proposed a couple of steps here (https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243), and this PR is step 0.

We only bring in the layout class as a private class, add detailed explanations and comments (thanks to an LLM), and add a unit test to showcase that the code indeed works as expected. The only big change we make here is that, instead of assuming the smallest-stride dimension goes from left to right inside coalesce and complement, we follow PyTorch's ordering and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 08:31:24 -07:00
01a3a8550a Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e., the device layout. The entries are all integers, but because meshes can have various shapes, internal bookkeeping becomes much more challenging. Currently our internal bookkeeping inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten`, we would need to introduce very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging the CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things go incrementally, we proposed a couple of steps here (https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243), and this PR is step 0.

We only bring in the layout class as a private class, add detailed explanations and comments (thanks to an LLM), and add a unit test to showcase that the code indeed works as expected. The only big change we make here is that, instead of assuming the smallest-stride dimension goes from left to right inside coalesce and complement, we follow PyTorch's ordering and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-19 22:42:05 -07:00
00e03cccf3 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e., the device layout. The entries are all integers, but because meshes can have various shapes, internal bookkeeping becomes much more challenging. Currently our internal bookkeeping inside DeviceMesh is not scalable, so in order to support new functions like `_unflatten`, we would need to introduce very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging the CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things go incrementally, we proposed a couple of steps here (https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243), and this PR is step 0.

We only bring in the layout class as a private class, add detailed explanations and comments (thanks to an LLM), and add a unit test to showcase that the code indeed works as expected. The only big change we make here is that, instead of assuming the smallest-stride dimension goes from left to right inside coalesce and complement, we follow PyTorch's ordering and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-19 21:27:37 -07:00
1a58e8dad5 [DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping
[ghstack-poisoned]
2025-08-19 15:57:40 -07:00
8 changed files with 638 additions and 270 deletions

View File

@@ -10,9 +10,8 @@ from numpy.testing import assert_array_equal
import torch
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
@@ -554,6 +553,40 @@ class DTensorTest(DTensorTestBase):
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
@with_comms
def test_dtensor_save_load_with_mesh_backend_decouple(self):
import io
# Turn on gate for not saving PG names for device mesh when it comes to torch.save.
DeviceMesh.decouple_backend_at_save = True
device_mesh = self.build_device_mesh()
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
buffer = io.BytesIO()
torch.save(sharded_tensor, buffer)
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=False)
self.assertFalse(hasattr(reloaded_st._spec.mesh, "_dim_group_names"))
self.assertNotEqual(sharded_tensor._spec.mesh, reloaded_st._spec.mesh)
self.assertEqual(
sharded_tensor.to_local().tolist(), reloaded_st.to_local().tolist()
)
self.assertEqual(sharded_tensor._spec.placements, reloaded_st._spec.placements)
reloaded_st._spec.mesh = device_mesh
self.assertEqual(sharded_tensor, reloaded_st)
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertFalse(hasattr(reloaded_st._spec.mesh, "_dim_group_names"))
self.assertNotEqual(sharded_tensor._spec.mesh, reloaded_st._spec.mesh)
self.assertEqual(
sharded_tensor.to_local().tolist(), reloaded_st.to_local().tolist()
)
self.assertEqual(sharded_tensor._spec.placements, reloaded_st._spec.placements)
reloaded_st._spec.mesh = device_mesh
self.assertEqual(sharded_tensor, reloaded_st)
DeviceMesh.decouple_backend_at_save = False
@skipIfHpu
@with_comms
@unittest.skipIf(
@@ -641,6 +674,7 @@ DTensorTestWithLocalTensor = create_local_tensor_test_class(
# integration
"test_dtensor_save_load",
"test_dtensor_save_load_import",
"test_dtensor_save_load_with_mesh_backend_decouple",
],
)

View File

@@ -1051,6 +1051,26 @@ class TestDeviceMeshGetItem(DTensorTestBase):
)
w.wait()
@with_comms
def test_unflatten_mesh_3d_with_pg_cache(self):
# Turn on gate for not saving PG names for device mesh when it comes to torch.save.
# This also turns on pg cache
DeviceMesh.decouple_backend_at_save = True
# Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism(EP).
global_mesh = init_device_mesh(
self.device_type,
(8,),
mesh_dim_names=("world",),
)
non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
# test pg caching when unflatten into same layout.
self.assertEqual(non_ep_mesh["dp"].get_group(), ep_mesh["dp"].get_group())
self.assertEqual(non_ep_mesh["tp"].get_group(), ep_mesh["ep_tp"].get_group())
DeviceMesh.decouple_backend_at_save = False
@with_comms
def test_concatenate_2d(self):
mesh_shape = (2, 4)

View File

@@ -1,5 +1,6 @@
#pragma once
#include <functional>
#include <memory>
#include <utility>
#include <vector>
@@ -48,6 +49,12 @@ class TORCH_API Backend : public torch::CustomClassHolder {
const std::string backend;
std::string group_name;
std::vector<uint64_t> global_ranks_in_group;
bool operator==(const Options& other) const noexcept {
return timeout == other.timeout && backend == other.backend &&
group_name == other.group_name &&
global_ranks_in_group == other.global_ranks_in_group;
}
};
explicit Backend(int rank, int size);
@@ -511,3 +518,24 @@ class TORCH_API Backend : public torch::CustomClassHolder {
};
} // namespace c10d
// small helper
inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}
namespace std {
template <>
struct hash<c10d::Backend::Options> {
std::size_t operator()(const c10d::Backend::Options& o) const noexcept {
std::size_t h = 0;
hash_combine(h, std::hash<long long>{}(o.timeout.count()));
hash_combine(h, std::hash<std::string>{}(o.backend));
hash_combine(h, std::hash<std::string>{}(o.group_name));
for (auto x : o.global_ranks_in_group)
hash_combine(h, std::hash<uint64_t>{}(x));
return h;
}
};
} // namespace std

View File

@@ -260,6 +260,23 @@ class TORCH_API ProcessGroupGloo : public Backend {
std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
int threads;
bool operator==(const Options& other) const noexcept {
// 1) compare base first
if (!static_cast<const Backend::Options&>(*this).operator==(other))
return false;
// 2) compare devices by identity
if (devices.size() != other.devices.size())
return false;
for (size_t i = 0; i < devices.size(); ++i) {
if (devices[i].get() != other.devices[i].get()) // pointer identity
return false;
}
// 3) compare added scalar fields
return threads == other.threads;
}
};
const std::string getBackendName() const override {
@@ -494,4 +511,24 @@ class TORCH_API ProcessGroupGloo : public Backend {
} // namespace c10d
namespace std {
template <>
struct hash<c10d::ProcessGroupGloo::Options> {
std::size_t operator()(
const c10d::ProcessGroupGloo::Options& o) const noexcept {
std::size_t h = 0;
// reuse base hash
hash_combine(
h,
std::hash<c10d::Backend::Options>{}(
static_cast<const c10d::Backend::Options&>(o)));
// add derived fields
for (auto const& dev : o.devices)
hash_combine(h, std::hash<const void*>{}(dev.get()));
hash_combine(h, std::hash<int>{}(o.threads));
return h;
}
};
} // namespace std
#endif // USE_C10D_GLOO

View File

@@ -550,6 +550,33 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead.
int split_color{-2};
#endif
bool operator==(const Options& other) const noexcept {
// 1) compare base first
if (!static_cast<const Backend::Options&>(*this).operator==(other))
return false;
// 2) simple fields
if (is_high_priority_stream != other.is_high_priority_stream) {
return false;
}
if (split_color != other.split_color) {
return false;
}
// 3) split_from: compare by identity
if (split_from.get() != other.split_from.get()) {
return false;
}
#ifdef NCCL_HAS_CONFIG
// 4) config
if (std::memcmp(&config, &other.config, sizeof(ncclConfig_t)) != 0) {
return false;
}
#endif
return true;
}
};
// Helper class related to TORCH_NCCL_DESYNC_DEBUG
@@ -1504,4 +1531,46 @@ typedef bool (*gil_checker_t)();
TORCH_API gil_checker_t& get_gil_checker();
} // namespace c10d
#ifdef NCCL_HAS_CONFIG
inline std::size_t hash_nccl_config(const ncclConfig_t& cfg) noexcept {
const unsigned char* p = reinterpret_cast<const unsigned char*>(&cfg);
std::size_t h = 0;
for (std::size_t i = 0; i < sizeof(cfg); ++i) {
hash_combine(h, static_cast<std::size_t>(p[i]));
}
return h;
}
#endif
namespace std {
template <>
struct hash<c10d::ProcessGroupNCCL::Options> {
std::size_t operator()(
const c10d::ProcessGroupNCCL::Options& o) const noexcept {
std::size_t h = 0;
// 1) base
hash_combine(
h,
std::hash<c10d::Backend::Options>{}(
static_cast<const c10d::Backend::Options&>(o)));
// 2) trivial extras
hash_combine(h, std::hash<bool>{}(o.is_high_priority_stream));
hash_combine(h, std::hash<int>{}(o.split_color));
// 3) pointer identity for split_from
hash_combine(h, std::hash<const void*>{}(o.split_from.get()));
#ifdef NCCL_HAS_CONFIG
// 4) config — option A: hash bytes
hash_combine(h, hash_nccl_config(o.config));
#endif
return h;
}
};
} // namespace std
#endif // USE_C10D_NCCL

View File

@@ -3107,7 +3107,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
.def_readwrite(
"global_ranks_in_group",
&::c10d::Backend::Options::global_ranks_in_group)
.def_readwrite("group_name", &::c10d::Backend::Options::group_name);
.def_readwrite("group_name", &::c10d::Backend::Options::group_name)
.def(
"__eq__",
[](const ::c10d::Backend::Options& a,
const ::c10d::Backend::Options& b) { return a == b; })
.def("__hash__", [](const ::c10d::Backend::Options& a) {
return std::hash<::c10d::Backend::Options>{}(a);
});
#ifdef USE_C10D_GLOO
auto processGroupGloo =
@@ -3121,7 +3128,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
processGroupGloo, "_Options", backendOptions)
.def(py::init<>())
.def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads)
.def(
"__eq__",
[](const ::c10d::ProcessGroupGloo::Options& a,
const ::c10d::ProcessGroupGloo::Options& b) { return a == b; })
.def("__hash__", [](const ::c10d::ProcessGroupGloo::Options& a) {
return std::hash<::c10d::ProcessGroupGloo::Options>{}(a);
});
processGroupGloo
.def_static(
@@ -3481,6 +3495,15 @@ Example::
"split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
.def_readwrite(
"split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
.def(
"__eq__",
[](const ::c10d::ProcessGroupNCCL::Options& a,
const ::c10d::ProcessGroupNCCL::Options& b) { return a == b; })
.def(
"__hash__",
[](const ::c10d::ProcessGroupNCCL::Options& a) {
return std::hash<::c10d::ProcessGroupNCCL::Options>{}(a);
})
.def(
"__copy__",
[](const ::c10d::ProcessGroupNCCL::Options& self) {

View File

@@ -951,7 +951,9 @@ class _LocalDeviceMesh:
coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
for r in lm.ranks:
rank_tensor = self._layout.remap_to_tensor(self._rank_map)
rank_tensor = self._layout.remap_to_tensor(
self._shared_state.get_rank_map()
)
rank_coords = (rank_tensor == r).nonzero().tolist()
assert len(rank_coords) == 1
for d, c in enumerate(rank_coords[0][1:]):

View File

@@ -6,7 +6,7 @@ import threading
import warnings
from collections.abc import Iterator
from itertools import zip_longest
from typing import Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
@@ -125,264 +125,67 @@
"""
return getattr(torch, device_type, None)
class DeviceMesh:
class _SharedState:
"""
DeviceMesh represents a mesh of devices, where layout of devices could be
represented as a n-d dimension array, and each value of the n-d dimensional
array is the global id of the default process group ranks.
DeviceMesh could be used to setup the N dimensional device connections across the cluster,
and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on
each dimension of the DeviceMesh separately. DeviceMesh respects the device that user selects
already (i.e. if user call `torch.cuda.set_device` before the DeviceMesh initialization),
and will select/set the device for the current process if user does not set the device
beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization.
DeviceMesh can also be used as a context manager when using together with DTensor APIs.
.. note::
DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
is running on all processes/ranks in the cluster. Therefore, users need to make sure the
`mesh` array (which describes the layout of devices) should be identical across all ranks.
Inconsistent `mesh` will lead to silent hang.
Args:
device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
of devices, where the IDs are global IDs of the default process group.
_rank (int): (experimental/internal)
The global rank of the current process. If not provided, it will
be inferred from the default process group.
Returns:
DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
The following program runs on each process/rank in an SPMD manner. In this example, we have 2
hosts with 4 GPUs each.
A reduction over the first dimension of mesh will reduce across
columns (0, 4), .. and (3, 7), a reduction over the second dimension
of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
Example::
>>> # xdoctest: +SKIP("no rank")
>>> from torch.distributed.device_mesh import DeviceMesh
>>>
>>> # Initialize device mesh as (2, 4) to represent the topology
>>> # of cross-host(dim 0), and within-host (dim 1).
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
This class is used to store the shared state of the DeviceMesh.
"""
# Flag to specify device save without backend info. This is a temporary variable
# We will remove this flag once we fully deprecate the behavior of save a device mesh with pg names.
_device_type: str
_rank_map: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout
_root_mesh: Optional["DeviceMesh"] = None
# Record flatten mesh name to its flattened mesh in root mesh.
_flatten_mapping: dict[str, "DeviceMesh"]
_root_mesh: Optional["DeviceMesh"]
_backend_cache: dict[tuple[_MeshLayout, Optional[C10dBackend.Options]], str]
def __init__(
self,
device_type: str,
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
*,
mesh_dim_names: Optional[tuple[str, ...]] = None,
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_rank_map: Optional[torch.Tensor] = None,
_root_mesh: Optional["DeviceMesh"] = None,
rank_map: torch.Tensor,
root_mesh: Optional["DeviceMesh"] = None,
) -> None:
        self._device_type = device_type
        self._rank_map = rank_map
        self._root_mesh = root_mesh
        self._backend_cache: dict[
            tuple[_MeshLayout, Optional[C10dBackend.Options]], str
        ] = {}
        self.pg_cache_enabled = DeviceMesh.decouple_backend_at_save
    def __post_init__(self):
        assert self._rank_map.ndim == 1, "The rank map must be 1-dimensional"
        assert self._rank_map.is_contiguous(), "The rank map must be contiguous"
def get_rank_map(self) -> torch.Tensor:
return self._rank_map
def get_root_mesh(self) -> Optional["DeviceMesh"]:
return self._root_mesh
    def get_device_type(self) -> str:
        return self._device_type
def update_backend_cache(
self,
layout: _MeshLayout,
backend: str,
pg_option: Optional[C10dBackend.Options],
) -> None:
if (layout, pg_option) not in self._backend_cache:
self._backend_cache[(layout, pg_option)] = backend
def get_backend_from_cache(
self, layout: _MeshLayout, pg_option: Optional[C10dBackend.Options]
) -> Optional[str]:
return self._backend_cache.get((layout, pg_option), None)
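    # A minimal sketch of the intended cache round-trip (the `layout` value is
    # hypothetical; any _MeshLayout works as a key):
    #
    #     >>> state = _SharedState("cuda", torch.arange(8, dtype=torch.int))
    #     >>> state.update_backend_cache(layout, "dim_0_pg", None)
    #     >>> state.get_backend_from_cache(layout, None)
    #     'dim_0_pg'
    #
    # The first PG created for a given (layout, pg_option) pair is thus reused
    # by every mesh that later asks for the same pair.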
    def _init_one_process_group(
        self,
        sub_layout: _MeshLayout,
        dim_name: str,
        backend_override: BackendConfig,
    ) -> Optional[str]:
        # Generate a 2D global mesh tensor for the current dim for PG creation.
        pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(self._rank_map)
backend, pg_options = backend_override
# We need to explicitly pass in timeout when specified in option, otherwise
# the default timeout will be used to override the timeout set in option.
@@ -469,10 +272,9 @@ else:
pg_name = dim_group.group_name
return pg_name
    def _init_process_groups(
        self,
        layout: _MeshLayout,
        mesh_dim_names: Optional[tuple[str, ...]],
        backend_override: tuple[BackendConfig, ...],
    ) -> list[str]:
@@ -482,18 +284,270 @@ else:
# create sub pgs base on the mesh argument specified
for dim in range(len(layout)):
dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}"
            backend_cache = None
            if self.pg_cache_enabled:
                backend_cache = self.get_backend_from_cache(
                    layout[dim], backend_override[dim][1]
                )
            if backend_cache is not None:
                dim_group_names.append(backend_cache)
            else:
                dim_group_names.append(
                    self._init_one_process_group(  # type: ignore[arg-type]
                        layout[dim], dim_name, backend_override[dim]
                    )
                )
                if dim_group_names[-1] is not None and self.pg_cache_enabled:
                    self.update_backend_cache(
                        layout[dim], dim_group_names[-1], backend_override[dim][1]
                    )
if any(n is None for n in dim_group_names):
assert all(n is None for n in dim_group_names)
return []
return dim_group_names
torch.serialization.add_safe_globals([_SharedState])
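# NOTE: registering _SharedState as a safe global lets a DeviceMesh serialized
# with torch.save be restored under torch.load(..., weights_only=True),
# assuming DeviceMesh and its other components are also allowlisted. A minimal
# sketch (the file name is illustrative):
#
#     >>> torch.save(mesh, "mesh.pt")
#     >>> restored = torch.load("mesh.pt", weights_only=True)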
class DeviceMesh:
"""
    DeviceMesh represents a mesh of devices, where the layout of devices can be
    represented as an n-dimensional array, and each value of the n-dimensional
    array is the global id of a rank in the default process group.
    DeviceMesh can be used to set up the N-dimensional device connections across the cluster,
    and to manage the ProcessGroups for N-dimensional parallelisms. Communications can happen on
    each dimension of the DeviceMesh separately. DeviceMesh respects the device that the user has
    already selected (i.e. if the user calls `torch.cuda.set_device` before the DeviceMesh initialization),
    and will select/set the device for the current process if the user has not set the device
    beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization.
    DeviceMesh can also be used as a context manager when used together with DTensor APIs.
    .. note::
        DeviceMesh follows the SPMD programming model, which means the same PyTorch Python program
        runs on all processes/ranks in the cluster. Therefore, users need to make sure the
        `mesh` array (which describes the layout of devices) is identical across all ranks.
        An inconsistent `mesh` will lead to a silent hang.
Args:
device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
of devices, where the IDs are global IDs of the default process group.
_rank (int): (experimental/internal)
The global rank of the current process. If not provided, it will
be inferred from the default process group.
Returns:
DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
The following program runs on each process/rank in an SPMD manner. In this example, we have 2
hosts with 4 GPUs each.
    A reduction over the first dimension of mesh will reduce across
    columns (0, 4), ..., and (3, 7), while a reduction over the second dimension
    of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
Example::
>>> # xdoctest: +SKIP("no rank")
>>> from torch.distributed.device_mesh import DeviceMesh
>>>
>>> # Initialize device mesh as (2, 4) to represent the topology
>>> # of cross-host(dim 0), and within-host (dim 1).
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
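        >>> # A named-dimension variant (a sketch; the dim names below are
        >>> # illustrative) makes per-dimension slicing explicit:
        >>> named = DeviceMesh(
        ...     device_type="cuda",
        ...     mesh=[[0, 1, 2, 3], [4, 5, 6, 7]],
        ...     mesh_dim_names=("host", "gpu"),
        ... )
        >>> gpu_mesh = named["gpu"]  # 1-D submesh over the within-host dim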
"""
    # Flag to allow saving a device mesh without backend info. This is a temporary variable;
    # we will remove this flag once we fully deprecate the behavior of saving a device mesh with pg names.
decouple_backend_at_save = False
_mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout
# Record flatten mesh name to its flattened mesh in root mesh.
_flatten_mapping: dict[str, "DeviceMesh"]
_shared_state: _SharedState
def __init__(
self,
device_type: str,
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
*,
mesh_dim_names: Optional[tuple[str, ...]] = None,
backend_override: Optional[tuple[BackendConfig, ...]] = None,
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_shared_state: Optional[_SharedState] = None,
) -> None:
if mesh is not None:
if _layout is not None or _shared_state is not None:
raise TypeError(
"Cannot provide _layout and/or _shared_state if passing explicit mesh"
)
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
mesh_tensor = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
_layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
rank_map = mesh_tensor.flatten()
self._shared_state = _SharedState(
device_type=device_type, rank_map=rank_map, root_mesh=self
)
else:
            if _layout is None or _shared_state is None:
                raise TypeError(
                    "The `mesh` argument is required; omitting it is for PRIVATE USAGE ONLY!"
                )
rank_map = _shared_state.get_rank_map()
self._shared_state = _shared_state
if self._shared_state.get_root_mesh() is None:
self._shared_state._root_mesh = self
if not _layout.check_non_overlap():
raise AssertionError(
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Internal bookkeeping for the device mesh.
self._layout = _layout
        assert self._shared_state.get_rank_map().numel() >= self._layout.cosize(), (
            f"The rank map contains {rank_map.numel()} elements, "
            f"which isn't large enough for layout {self._layout}"
        )
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
if backend_override is None:
backend_override = ((None, None),) * len(self._layout)
elif len(backend_override) != len(self._layout):
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {len(self._layout)}."
)
        # Because we still need to support slicing of a flattened dim from the root mesh, we don't check stride here.
        if self._layout.numel() != self.mesh.numel():
            raise AssertionError(
                "Please use a valid layout when creating a DeviceMesh. "
                f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
            )
# private field to pre-generate DeviceMesh's hash
self._flatten_rank_map = tuple(self._shared_state.get_rank_map().tolist())
self._thread_id = None
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
# process (we need to know if the current global rank is in the mesh or not).
if _init_backend:
self._setup_world_group_and_device()
self._dim_group_names = self._shared_state._init_process_groups(
self._layout,
self._mesh_dim_names,
backend_override,
)
if is_initialized() and get_backend() == "threaded":
# pyrefly: ignore [bad-assignment]
self._thread_id = threading.get_ident()
if _rank is None:
_rank = get_rank()
# calculate the coordinates of the current global rank on the mesh
rank_coords = (self.mesh == _rank).nonzero()
if rank_coords.size(0) not in (0, 1):
raise AssertionError(
f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}"
)
self._coordinate_on_dim: Optional[list[int]] = (
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
@property
def device_type(self) -> str:
"""Returns the device type of the mesh."""
return self._shared_state.get_device_type()
@property
def mesh(self) -> torch.Tensor:
"""Returns the tensor representing the layout of devices."""
full_mesh = self._layout.remap_to_tensor(self._shared_state.get_rank_map())
if full_mesh.size(0) == 1:
return full_mesh[0]
my_coords = (full_mesh == get_rank()).nonzero()
if my_coords.size(0) > 0:
return full_mesh[my_coords[0, 0]]
raise RuntimeError(
"In order to get the mesh Tensor of a DeviceMesh it needs to "
"either have all its original dimensions (e.g., no slicing) "
"or it needs to contain the local rank"
)
@property
def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
"""Returns the names of mesh dimensions."""
return self._mesh_dim_names
def _setup_world_group_and_device(self):
default_initialized = is_initialized()
# TODO: think about how to allow pg options to be passed to world group
# or mesh dimension groups
if not default_initialized:
init_process_group()
world_size = get_world_size()
if self._layout.numel() > world_size:
raise RuntimeError(
f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
)
        # ONLY set the device if the current device is not initialized; if the user already
        # set the device before DeviceMesh init, we respect the user's choice.
        device_handle = _get_device_handle(self._shared_state.get_device_type())
        if device_handle and not device_handle.is_initialized():
            # Auto-set the cuda/cuda-like device only if the user has not set it. If there's a
            # LOCAL_RANK env variable from the launcher, we use it to set the device.
if "LOCAL_RANK" in os.environ:
local_rank = int(os.environ["LOCAL_RANK"])
logger.info(
"Setting default device for the current process based on LOCAL_RANK=%s",
local_rank,
)
device_handle.set_device(local_rank)
else:
                warnings.warn(
                    "It seems like you did not set/select the default device for the current process before the DeviceMesh "
                    "initialization, or use a launcher (e.g. torchrun) which populates the `LOCAL_RANK` environment variable. "
                    "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that "
                    "the underlying communicator (e.g. NCCL) can be initialized properly. "
                    "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the "
                    "device_id via `global_rank % num_devices_per_host`, assuming a homogeneous hardware cluster.",
                    stacklevel=2,
                )
                # Heuristic to set the current cuda/cuda-like device based on the number of GPU devices available on each host.
                # NOTE: This device selection only works for homogeneous hardware.
                num_devices_per_host = device_handle.device_count()
                if (
                    world_size > num_devices_per_host
                    and world_size % num_devices_per_host != 0
                ):
                    raise RuntimeError(
                        f"DeviceMesh only supports homogeneous hardware, but found "
                        f"{world_size} ranks and {num_devices_per_host} {self._shared_state.get_device_type()} devices!"
                    )
device_handle.set_device(get_rank() % num_devices_per_host)
return _get_default_group()
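    # Recommended pattern (a sketch; assumes a torchrun-style launcher that
    # sets LOCAL_RANK): select the device before constructing the mesh so the
    # communicator binds to the right device.
    #
    #     >>> local_rank = int(os.environ["LOCAL_RANK"])
    #     >>> torch.cuda.set_device(local_rank)
    #     >>> mesh = init_device_mesh("cuda", (2, 4))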
def _get_root_mesh(self) -> "DeviceMesh":
        root_mesh = self._shared_state.get_root_mesh()
        return root_mesh if root_mesh is not None else self
def __enter__(self) -> "DeviceMesh":
# set this mesh as the current mesh in mesh env
@@ -523,9 +577,9 @@ else:
if not self._hash:
self._hash = hash(
(
                    self._get_universe_id(),
                    self._layout,
                    self._shared_state.get_device_type(),
self._mesh_dim_names,
self._thread_id,
)
@@ -538,9 +592,10 @@ else:
if not isinstance(other, DeviceMesh):
return False
return (
            self._get_universe_id() == other._get_universe_id()
            and self._layout == other._layout
            and self._shared_state.get_device_type()
            == other._shared_state.get_device_type()
and self._mesh_dim_names == other._mesh_dim_names
and self._thread_id == other._thread_id
)
@@ -695,12 +750,11 @@ else:
]
)
        res_submesh = DeviceMesh(
            self._shared_state.get_device_type(),
            _layout=layout,
            mesh_dim_names=submesh_dim_names,
            _init_backend=False,
            _shared_state=root_mesh._shared_state,
        )
res_submesh._dim_group_names = slice_dim_group_name
return res_submesh
@@ -745,12 +799,11 @@ else:
)
        res_flattened_mesh = DeviceMesh(
            root_mesh._shared_state.get_device_type(),
            _layout=flattened_mesh_layout,
            mesh_dim_names=(mesh_dim_name,),
            backend_override=(backend_override,),
            _shared_state=root_mesh._shared_state,
        )
root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh
@@ -874,12 +927,12 @@ else:
"""
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
layout = self._layout[mesh_dim]
        pg_ranks_by_dim = layout.remap_to_tensor(self._shared_state.get_rank_map())
cur_rank = self.get_rank()
res_submeshes = []
for mesh_1d in pg_ranks_by_dim:
submesh = DeviceMesh(
                self._shared_state.get_device_type(),
mesh_1d,
mesh_dim_names=(mesh_dim_name,),
_init_backend=False,
@@ -1051,6 +1104,12 @@ else:
)
return not_none(get_rank(mesh_dim_group))
def _get_universe_id(self) -> Union[tuple[int, ...], int]:
if self.decouple_backend_at_save:
return id(self._shared_state.get_rank_map())
else:
return self._flatten_rank_map
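    # Equality implication (a descriptive note): with decouple_backend_at_save
    # enabled, two meshes share a universe only if they share the *same*
    # rank_map tensor object (identity); otherwise the flattened rank values
    # themselves are compared.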
def get_coordinate(self) -> Optional[list[int]]:
"""
Return the relative indices of this rank relative to all
@@ -1118,10 +1177,9 @@ else:
        res_mesh = DeviceMesh(
            self.device_type,
            _layout=unflattened_layout,
            mesh_dim_names=tuple(unflattened_mesh_dim_names),
            _init_backend=False,
            _shared_state=root_mesh._shared_state,
        )
# If original mesh has initiated its backend, we need to initialize the backend
@@ -1130,11 +1188,12 @@ else:
# per dim backend init.
if hasattr(self, "_dim_group_names"):
dim_group_names = self._dim_group_names.copy()
            dim_group_names[dim : dim + 1] = (
                root_mesh._shared_state._init_process_groups(
                    partial_layout,
                    mesh_dim_names,
                    backend_override,
                )
            )
res_mesh._dim_group_names = dim_group_names
@@ -1210,7 +1269,7 @@ else:
concat_sizes: list[IntTuple] = []
concat_strides: list[IntTuple] = []
concat_dim_group_name: list[str] = []
        mesh_universe_id = device_mesh_list[0]._get_universe_id()
for dm in device_mesh_list:
for i in range(len(dm._layout)):
concat_sizes.append(dm._layout[i].sizes)
@@ -1219,7 +1278,7 @@ else:
concat_dim_group_name.extend(not_none(dm._dim_group_names))
            # Concatenating device meshes that have different root mesh tensors is meaningless,
            # because the concatenated indices should be indexed by the same root mesh tensor.
            if dm._get_universe_id() != mesh_universe_id:
                raise RuntimeError(
                    "Cannot concatenate DeviceMeshes derived from different device meshes"
                )
@@ -1231,14 +1290,109 @@ else:
        res_mesh = DeviceMesh(
            device_mesh_list[0].device_type,
            _layout=concat_mesh_layout,
            mesh_dim_names=tuple(concat_dim_names),
            _init_backend=False,
            _shared_state=device_mesh_list[0]._shared_state,
        )
res_mesh._dim_group_names = concat_dim_group_name
return res_mesh
def __getstate__(self):
"""
Returns the state of the DeviceMesh as a dictionary for serialization,
which contains all necessary information to reconstruct the DeviceMesh.
"""
shared_state = {
"device_type": self._shared_state._device_type,
"rank_map": self._shared_state._rank_map,
}
if self._shared_state._root_mesh != self:
shared_state["root_mesh"] = not_none(
self._shared_state._root_mesh
).__getstate__()
state: dict[str, Any] = {
"shared_state": shared_state,
"layout": self._layout,
"mesh_dim_names": self._mesh_dim_names,
"thread_id": self._thread_id,
"coordinate_on_dim": getattr(self, "_coordinate_on_dim", None),
}
# Serialize flatten_mapping
flatten_mapping: dict[str, Any] = {}
for mesh_name, mesh in self._flatten_mapping.items():
flatten_mapping[mesh_name] = mesh.__getstate__()
state["flatten_mapping"] = flatten_mapping
if not self.decouple_backend_at_save and hasattr(self, "_dim_group_names"):
            logger.warning(
                "Saving a device mesh via torch.save with pg names will be deprecated in PT 2.11. "
                "Users are welcome to use Distributed Checkpoint (DCP), or to re-create pgs in the same order "
                "as in the original device mesh."
            )
state["dim_group_names"] = self._dim_group_names
return state
def __setstate__(self, state):
"""
Restores the DeviceMesh state from a state dictionary.
"""
required_keys = {
"shared_state",
"layout",
"mesh_dim_names",
"thread_id",
"coordinate_on_dim",
"flatten_mapping",
}
missing_keys = required_keys - state.keys()
if missing_keys:
raise ValueError(f"state_dict is missing required keys: {missing_keys}")
# Restore shared_state
shared_state = state["shared_state"]
# First, restore root_mesh if it exists (we need to do this before creating _SharedState)
root_mesh = None
if shared_state.get("root_mesh") is not None:
# Create a new DeviceMesh for the root mesh
root_mesh = DeviceMesh.__new__(DeviceMesh)
root_mesh.__setstate__(shared_state["root_mesh"])
# Create and initialize the shared state
self._shared_state = _SharedState(
device_type=shared_state["device_type"],
rank_map=shared_state["rank_map"],
root_mesh=root_mesh,
)
# Restore other attributes
self._layout = state["layout"]
self._mesh_dim_names = state["mesh_dim_names"]
self._thread_id = state["thread_id"]
if state.get("coordinate_on_dim") is not None:
self._coordinate_on_dim = state["coordinate_on_dim"]
# Re-initialize internal bookkeeping
self._flatten_rank_map = tuple(self._shared_state._rank_map.tolist())
# Restore flatten_mapping
self._flatten_mapping = {}
if state.get("flatten_mapping"):
for mesh_name, mesh_state in state["flatten_mapping"].items():
flatten_mesh = DeviceMesh.__new__(DeviceMesh)
flatten_mesh.__setstate__(mesh_state)
self._flatten_mapping[mesh_name] = flatten_mesh
        # We don't recommend loading from saved pg names, because users would need to re-create
        # the process groups in exactly the same order as when the device mesh was saved.
        # This is implicit and error-prone, and we will remove this behavior soon.
        # Instead, we recommend that users explicitly create PGs and set them on the loaded mesh.
if state.get("dim_group_names"):
self._dim_group_names = state["dim_group_names"]
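    # A usage sketch of the recommended save/load flow (file name illustrative;
    # PG re-creation is up to the caller):
    #
    #     >>> DeviceMesh.decouple_backend_at_save = True
    #     >>> torch.save(mesh, "mesh.pt")
    #     >>> loaded = torch.load("mesh.pt", weights_only=True)
    #     >>> # explicitly re-create the process groups for `loaded` here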
def _normalize_backend_override(
backend_override: dict[
Union[int, str],
@@ -1363,12 +1517,13 @@ else:
# external device type has been set to be (e.g. meta)
with torch.device("cpu"):
rank_map = torch.arange(layout.numel(), dtype=torch.int)
shared_state = _SharedState(device_type=device_type, rank_map=rank_map)
device_mesh = DeviceMesh(
device_type=device_type,
_layout=layout,
mesh_dim_names=mesh_dim_names,
backend_override=backend_override_tuple,
_shared_state=shared_state,
)
return device_mesh
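# End-to-end sketch (assumes 8 ranks launched via torchrun; the dim names are
# illustrative):
#
#     >>> mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
#     >>> tp_mesh = mesh_2d["tp"]      # slicing reuses the shared rank_map/pg cache
#     >>> tp_pg = tp_mesh.get_group()  # per-dim process group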