Compare commits

..

67 Commits

Author SHA1 Message Date
2cc9ff2f6a Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.
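As a rough illustration only (the class and attribute names below are hypothetical, not the PR's actual implementation), such a shared bookkeeping object could look like:

```python
# Hypothetical sketch of a shared bookkeeping object for DeviceMesh.
# Names (_MeshSharedState, pg_cache, get_or_create_pg) are made up here.
from typing import Callable

import torch
from torch.distributed.device_mesh import DeviceMesh


class _MeshSharedState:
    """Singleton-style container shared by a root mesh and all its submeshes."""

    def __init__(self, root_mesh: DeviceMesh, rank_map: torch.Tensor) -> None:
        self.root_mesh = root_mesh             # the mesh every submesh derives from
        self.rank_map = rank_map               # global rank layout of the root mesh
        self.pg_cache: dict[object, str] = {}  # layout key -> process-group name

    def get_or_create_pg(self, layout_key, create_fn: Callable[[object], str]) -> str:
        # Reuse the process group created for this layout if we have seen it
        # before; otherwise create it once and cache it.
        if layout_key not in self.pg_cache:
            self.pg_cache[layout_key] = create_fn(layout_key)
        return self.pg_cache[layout_key]
```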

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:29:48 -07:00
6d41a72865 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 16:27:09 -07:00
895795f07c [ROCm][CI] forward fix kineto submodule bump (#166421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166421
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-28 17:40:23 +00:00
150eac735f Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-28 10:30:23 -07:00
2dc56456cb refactor: pull _replace_node common functionality out of Scheduler.finalize_multi_template_buffers (#163368)
Pull the replace_node function out of Scheduler.finalize_multi_template_buffers(). This is needed by the next PR (#163369). As part of this, also pull _replace_operation_buffer() up to top level, since it needs no self references.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163368
Approved by: https://github.com/PaulZhang12
2025-10-28 17:21:52 +00:00
2ea430ce25 Update base for Update on "[For discussion][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map"
We want to create a shared_state to store the root_mesh, rank_map, and pg caches. We can add more to it down the road, so that it becomes a singleton for bookkeeping and also aligns with our original proposal to move toward the idea of a mesh universe.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-27 16:22:44 -07:00
b9a6e020a2 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.
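For context, a hedged sketch of the intended 2D (FSDP+TP) usage, assuming 8 ranks; the concatenate entry point shown here (`_concatenate`) is a placeholder and may differ from the API actually landed in this stack:

```python
# Sketch only: assumes torch.distributed is initialized with 8 ranks.
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh_2d["dp"]
tp_mesh = mesh_2d["tp"]

# Before: FSDP slices the SPMD mesh back out of the root mesh.
spmd_mesh_old = mesh_2d["dp", "tp"]

# After (as described above): combine the two submeshes directly.
# `_concatenate` is a hypothetical name for the proposed API.
spmd_mesh_new = DeviceMesh._concatenate([dp_mesh, tp_mesh])
```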

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 10:22:06 -07:00
310db76d36 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-21 10:22:06 -07:00
4a70f2d033 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-16 15:41:23 -07:00
82fc673005 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-16 15:41:23 -07:00
e587947b20 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 15:55:15 -07:00
ff0ebf7fe5 Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 15:55:15 -07:00
7a85a3d289 Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 11:39:19 -07:00
baf793c54d Update base for Update on "[DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh"
With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root mesh.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-15 11:39:19 -07:00
e0c47235a3 [DeviceMesh][2D] Use concatenate for 2D (FSDP+TP) instead of getting from root mesh
[ghstack-poisoned]
2025-10-14 17:13:09 -07:00
f6b4fe1c64 Update on "[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.
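A minimal sketch of the constraint described above, again using a hypothetical `_concatenate` helper whose real name and signature may differ:

```python
# Sketch: every submesh passed to a concatenate-style API must be derived
# from the same root mesh, otherwise its rank indices cannot be lined up
# with the others for mesh indexing and device allocation.
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

root = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))

# OK: both submeshes were sliced out of the same root mesh.
spmd = DeviceMesh._concatenate([root["dp"], root["tp"]])  # hypothetical API

# Expected to be rejected: a submesh from a different root mesh has
# indices that are not comparable with the first one.
# other_root = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
# DeviceMesh._concatenate([root["dp"], other_root["tp"]])
```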


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 14:11:13 -07:00
092d3778a3 Update base for Update on "[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 14:11:13 -07:00
82b416f4db Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:56:12 -07:00
36233be9d3 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:56:12 -07:00
f56e3c8fdf Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:25:44 -07:00
ad324d0e3c Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-14 11:25:44 -07:00
a4dc3dafee Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-12 21:41:06 -07:00
766493267e Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-12 21:41:06 -07:00
bd086ac1b3 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-01 15:35:10 -07:00
411a5c7f7f Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-10-01 15:35:10 -07:00
3bab82c453 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-30 16:06:59 -07:00
11a624dc28 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-30 16:06:59 -07:00
4736bb57e3 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-26 22:14:44 -07:00
c2aaa5664c Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-26 22:14:44 -07:00
569a9000a5 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-25 16:17:40 -07:00
73f13abc50 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-25 16:17:40 -07:00
3038d7f285 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-24 14:40:29 -07:00
57448253f3 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-24 14:40:29 -07:00
32caf41d72 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 21:09:56 -07:00
28185f4406 Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 21:09:56 -07:00
92c709c202 Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 14:41:26 -07:00
f094af1e1a Update base for Update on "[For Discussion][DeviceMesh] Implement a concatenate api for submesh"
Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate some submeshes into a bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense for mesh indexing and device allocation.


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2025-09-19 14:41:26 -07:00
09db0ef757 [For Discussion][DeviceMesh] Implement a concatenate api for submesh
[ghstack-poisoned]
2025-09-19 11:34:56 -07:00
ecd3f525d5 Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-18 21:00:54 -07:00
b2345c972f Update base for Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-18 21:00:54 -07:00
fe44a87ed4 Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-17 22:50:00 -07:00
f0aa9cfc42 Update base for Update on "[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-09-17 22:50:00 -07:00
0ea286d26d Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:37:29 -07:00
621b9f2be8 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:37:29 -07:00
5ac6d410aa Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:32:59 -07:00
2bb8d19968 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 14:32:59 -07:00
2cd038d95d Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:59:38 -07:00
5decc0e164 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:59:38 -07:00
bafbc39603 Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:00:28 -07:00
356abd0719 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-26 09:00:28 -07:00
457bdbfaa4 Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-25 11:23:34 -07:00
5196ba3db4 Update base for Update on "[WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-25 11:23:34 -07:00
bf108cf3c9 [WIP][device_mesh] Implement _unflatten on top of CuTe layout bookkeeping
[ghstack-poisoned]
2025-08-21 16:50:56 -07:00
389519a03c Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh so that we:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users perform many operations without initializing any backend.

Concretely, in this PR, we do the following:
1. Replace all index/offset mappings with a CuTe Layout plus a backend class that handles all the bookkeeping and creates backends if needed. Use the CuTe layout for both slicing and _flatten.
2. Start making DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: all device meshes transformed out of it (slicing, flatten, unflatten, etc.) share the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly.
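A hedged sketch of the sharing behavior described in point 2, assuming an 8-rank world; `_flatten` is a private method, and the exact sharing guarantees are as stated above rather than verified here:

```python
# Sketch of the "universe" idea: meshes derived from the same root share
# backends (process groups); a freshly initialized mesh does not.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Derived meshes stay in the same universe and reuse the same PGs.
dp_mesh = mesh["dp"]
flat_mesh = mesh._flatten()  # private API, shown only for illustration

# A newly initialized mesh is a different universe with its own backend state.
other_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
```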

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:20:20 -07:00
873ec8442e Update base for Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh so that we:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users perform many operations without initializing any backend.

Concretely, in this PR, we do the following:
1. Replace all index/offset mappings with a CuTe Layout plus a backend class that handles all the bookkeeping and creates backends if needed. Use the CuTe layout for both slicing and _flatten.
2. Start making DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: all device meshes transformed out of it (slicing, flatten, unflatten, etc.) share the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:20:20 -07:00
5841ede067 Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh so that we:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users perform many operations without initializing any backend.

Concretely, in this PR, we do the following:
1. Replace all index/offset mappings with a CuTe Layout plus a backend class that handles all the bookkeeping and creates backends if needed. Use the CuTe layout for both slicing and _flatten.
2. Start making DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: all device meshes transformed out of it (slicing, flatten, unflatten, etc.) share the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:13:33 -07:00
4c90367bfb Update base for Update on "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to refactor the internal bookkeeping of DeviceMesh so that we:
1. Simplify the bookkeeping logic and make it generic enough that it is easy to support new transformations like flattening noncontiguous dims, reshape, and unflatten. (We leveraged the CuTe layout.)
2. Separate the backend from the mesh operations so that we can eventually let users perform many operations without initializing any backend.

Concretely, in this PR, we do the following:
1. Replace all index/offset mappings with a CuTe Layout plus a backend class that handles all the bookkeeping and creates backends if needed. Use the CuTe layout for both slicing and _flatten.
2. Start making DeviceMesh more functional (first from the backend perspective). Each newly created device mesh is like a universe: all device meshes transformed out of it (slicing, flatten, unflatten, etc.) share the same backend (PG), while creating a new device mesh starts a different universe. We changed our unit tests accordingly.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta 

[ghstack-poisoned]
2025-08-21 13:13:32 -07:00
27910fb22a Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
We want to implement the 


cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 11:46:45 -07:00
e5cbce1780 Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 11:03:28 -07:00
4dd9c8cf2d Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-21 08:55:35 -07:00
8c0643a671 Update on "[WIP][DeviceMesh] Simplifying internal bookkeeping with CuTe layout"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 22:34:22 -07:00
0f7736dd55 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e., the device layout. Its entries are all integers, but because meshes can have various shapes, the internal bookkeeping becomes much more challenging. Our current internal bookkeeping inside DeviceMesh is not scalable, so supporting new functions like `_unflatten` would require very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things incremental, we propose a couple of steps here https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243, and this PR is step 0.

We only bring in the layout class as a private class, add detailed explanations and comments (thanks to an LLM), and add a unit test to showcase that the code indeed works as expected. The only big change we make here is that instead of assuming the smallest-stride dimension runs from left to right inside coalesce and complement, we follow PyTorch's convention and keep the innermost dimension rightmost.
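As a worked illustration of the layout idea and the innermost-dimension-rightmost convention mentioned above (plain Python, independent of the private class this PR adds):

```python
# A (sizes, strides) layout maps a multi-dimensional coordinate to a flat
# rank index, PyTorch-style: the rightmost dimension is the innermost one.
sizes = (2, 4)     # a 2 x 4 mesh of 8 ranks
strides = (4, 1)   # row-major, so the rightmost dim has stride 1

def to_rank(coord, strides=strides):
    return sum(c * s for c, s in zip(coord, strides))

assert to_rank((0, 3)) == 3
assert to_rank((1, 2)) == 6

# Coalescing merges adjacent dims whose strides line up: sizes (2, 4) with
# strides (4, 1) covers exactly the same ranks as the single dim (8,) / (1,).
assert [to_rank((i, j)) for i in range(2) for j in range(4)] == list(range(8))
```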

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 16:02:43 -07:00
9e8f81aaa8 [DeviceMesh] Simplifying internal bookkeeping with CuTe layout
[ghstack-poisoned]
2025-08-20 16:02:43 -07:00
72706c7cb9 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e., the device layout. Its entries are all integers, but because meshes can have various shapes, the internal bookkeeping becomes much more challenging. Our current internal bookkeeping inside DeviceMesh is not scalable, so supporting new functions like `_unflatten` would require very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things incremental, we propose a couple of steps here https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243, and this PR is step 0.

We only bring in the layout class as a private class, add detailed explanations and comments (thanks to an LLM), and add a unit test to showcase that the code indeed works as expected. The only big change we make here is that instead of assuming the smallest-stride dimension runs from left to right inside coalesce and complement, we follow PyTorch's convention and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-20 08:31:24 -07:00
01a3a8550a Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e., the device layout. Its entries are all integers, but because meshes can have various shapes, the internal bookkeeping becomes much more challenging. Our current internal bookkeeping inside DeviceMesh is not scalable, so supporting new functions like `_unflatten` would require very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things incremental, we propose a couple of steps here https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243, and this PR is step 0.

We only bring in the layout class as a private class, add detailed explanations and comments (thanks to an LLM), and add a unit test to showcase that the code indeed works as expected. The only big change we make here is that instead of assuming the smallest-stride dimension runs from left to right inside coalesce and complement, we follow PyTorch's convention and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-19 22:42:05 -07:00
00e03cccf3 Update on "[DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping"
DeviceMesh is essentially a way to specify how devices interact with each other, i.e., the device layout. Its entries are all integers, but because meshes can have various shapes, the internal bookkeeping becomes much more challenging. Our current internal bookkeeping inside DeviceMesh is not scalable, so supporting new functions like `_unflatten` would require very complicated logic inside DeviceMesh, as pointed out in this comment (https://github.com/pytorch/pytorch/pull/159482/files#r2256025452). Thanks to lw's suggestion and PoC PR (https://github.com/pytorch/pytorch/pull/160429), we realized that leveraging CuTe layout algebra ([ref](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html)) from Cutlass will greatly simplify our internal mechanical bookkeeping and make the abstraction ops much easier to build on top of it. To make things incremental, we propose a couple of steps here https://github.com/pytorch/pytorch/issues/160337#issuecomment-3195106243, and this PR is step 0.

We only bring in the layout class as a private class, add detailed explanations and comments (thanks to an LLM), and add a unit test to showcase that the code indeed works as expected. The only big change we make here is that instead of assuming the smallest-stride dimension runs from left to right inside coalesce and complement, we follow PyTorch's convention and keep the innermost dimension rightmost.

More PRs are on the way.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
2025-08-19 21:27:37 -07:00
1a58e8dad5 [DeviceMesh] Introduce CuTe layout into devicemesh code base for internal bookkeeping
[ghstack-poisoned]
2025-08-19 15:57:40 -07:00
5 changed files with 188 additions and 172 deletions

View File

@@ -892,10 +892,16 @@ fn(torch.randn(5))
os.remove(
file_path
) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
empty_line_normalizer(lines),
empty_line_normalizer(stderr.decode("utf-8")),
)
orig_maxDiff = unittest.TestCase.maxDiff
unittest.TestCase.maxDiff = None
try:
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
empty_line_normalizer(lines),
empty_line_normalizer(stderr.decode("utf-8")),
)
except Exception:
unittest.TestCase.maxDiff = orig_maxDiff
raise
@make_settings_test("torch._dynamo.eval_frame")
def test_log_traced_frames(self, records):

View File

@@ -529,7 +529,7 @@ class TestProfiler(TestCase):
found_mm = True
if "gemm" in e.name.lower() or "Cijk" in e.name:
found_gemm = True
if "memcpy" in e.name.lower():
if "memcpy" in e.name.lower() or "__amd_rocclr_copyBuffer" in e.name:
found_memcpy = True
if use_cuda:
self.assertTrue(found_gemm)

View File

@@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True
# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = False
enable_cpp_symbolic_shape_guards = not is_fbcode()
# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True

View File

@@ -409,9 +409,10 @@ class SchedulerDonatedBuffer(SchedulerBuffer):
class BaseSchedulerNode:
ancestors: OrderedSet[str]
debug_device_str: Callable[[BaseSchedulerNode], list[str]]
group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
last_usage: OrderedSet[str]
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
@@ -420,22 +421,24 @@ class BaseSchedulerNode:
min_order: int
max_order: int
mpi_node: MemoryPlanningInfoForNode
mutation_renames: dict[str, str]
node: Optional[ir.Operation]
outputs: list[SchedulerBuffer]
outputs_by_name: dict[str, SchedulerBuffer]
override_estimated_runtime: Optional[float] = None
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
lambda *args, **kwargs: []
)
self.scheduler = scheduler
self.debug_device_str = lambda *args, **kwargs: []
def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node
self.ancestors: OrderedSet[str] = OrderedSet()
self.last_usage = OrderedSet[
str
]() # buffers that won't be used after this kernel
self.node = node
self.ancestors = OrderedSet()
self.last_usage = OrderedSet() # buffers that won't be used after this kernel
self.written = False
self.outputs: list[SchedulerBuffer] = [
self.outputs = [
SchedulerBuffer(
scheduler=self.scheduler,
node=output,
@@ -443,16 +446,14 @@ class BaseSchedulerNode:
)
for output in node.get_outputs()
]
self.outputs_by_name: dict[str, SchedulerBuffer] = {
buf.get_name(): buf for buf in self.outputs
}
self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}
# mutation_renames for the current node. Due to potential
# more mutations happening later, this can be different
# to Scheduler.mutation_renames. Also this dict should be small
# since only mutation information relevant to the deps for this
# node is stored here.
self.mutation_renames: dict[str, str] = {}
self.mutation_renames = {}
def __repr__(self) -> str:
return f"{type(self).__name__}(name={self.get_name()!r})"
@@ -2435,6 +2436,34 @@ def pick_loop_order(
return order
def _replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name
del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name
orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node
orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node
@dataclasses.dataclass
class NodeUser:
node: Union[BaseSchedulerNode, OutputNode]
@@ -3336,33 +3365,6 @@ class Scheduler:
will force completion of compilation and benchmarking.
"""
def replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name
del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name
orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node
orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node
for i, node in enumerate(self.nodes):
if isinstance(node, SchedulerNode) and isinstance(
node.node, ir.MultiTemplateBuffer
@@ -3416,40 +3418,47 @@ class Scheduler:
assign_origin_node(out_tensorbox, multi_node.origin_node)
out_buffer.layout = multi_node.layout
replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
self._replace_node(out_buffer, multi_node, i, node)
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
def _replace_node(
self,
out_buffer: ir.OperationBuffer,
multi_node: ir.MultiTemplateBuffer,
i: int,
node: SchedulerNode,
) -> None:
_replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(
node.read_writes.reads, node.unmet_dependencies
):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
return any(

View File

@@ -350,22 +350,33 @@ else:
return _get_default_group()
@staticmethod
def _init_process_groups(
layout: _MeshLayout,
def _init_one_process_group(
sub_layout: _MeshLayout,
rank_map: torch.Tensor,
mesh_dim_names: Optional[tuple[str, ...]],
backend_override: tuple[BackendConfig, ...],
) -> list[str]:
# group_name associated with each mesh dimension, each
# mesh dimension should have one sub-group per rank
#
dim_group_names: list[str] = []
dim_name: str,
backend_override: BackendConfig,
) -> Optional[str]:
# Generate a 2D global mesh tensor for the current dim for PG creation.
pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map)
backend, pg_options = backend_override
# We need to explicitly pass in timeout when specified in option, otherwise
# the default timeout will be used to override the timeout set in option.
# TODO: remove this once we have fixed inside c10d level.
timeout = pg_options._timeout if pg_options else None
# If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
# of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
# If the mesh doesn't have a mesh_dim_names, then the group description of the
# subgroup would be `mesh_dim_0` and `mesh_dim_1`.
group_desc = f"mesh_{dim_name}"
dim_group = None
default_group = _get_default_group()
if (
len(layout) == 1
and layout.numel() == get_world_size()
and backend_override[0] == (None, None)
# Early return if there is only one sub_layout in the mesh layout.
if sub_layout.numel() == get_world_size() and backend_override == (
None,
None,
):
# Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
# Otherwise, create new pg.
@@ -380,90 +391,80 @@ else:
and get_backend(default_group) == "gloo"
else default_group
)
dim_group_names.append(dim_group.group_name)
else:
# create sub pgs base on the mesh argument specified
for dim in range(len(layout)):
# swap the current dim to the last dim
# then reshape to flatten out other dims
pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
backend, pg_options = backend_override[dim]
# We need to explicitly pass in timeout when specified in option, otherwise
# the default timeout will be used to override the timeout set in option.
# TODO: remove this once we have fixed inside c10d level.
timeout = pg_options._timeout if pg_options else None
return dim_group.group_name # type: ignore[union-attr]
# If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
# of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
# If the mesh doesn't not have a mesh_dim_names, then the group description of the
# subgroup would be `mesh_dim_0` and `mesh_dim_1`.
group_desc = (
f"mesh_{mesh_dim_names[dim]}"
if mesh_dim_names
else f"mesh_dim_{dim}"
# If bound_device_id exists, it means the nccl communicator has been eagerly initialized
# so that we can use `split_group` to create subgroups through `ncclCommSplit`.
# In this case, we only need to make one API call (`split_group``) for the subgroup creation
# for each mesh dimension. In a 2 * 4 mesh, we only need to make two API calls per ranks to create
# all the subgroups.
# Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The
# numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
# mesh, we need to make two API calls per ranks to create all the subgroups.
if (
getattr(default_group, "bound_device_id", None) is not None
and torch.cuda.is_available()
and (
backend is None
or default_group._get_backend(torch.device("cuda")).name()
== backend
)
):
dim_group = split_group(
parent_pg=default_group,
timeout=timeout,
pg_options=pg_options,
split_ranks=pg_ranks_by_dim.tolist(),
group_desc=group_desc,
)
return dim_group.group_name # type: ignore[union-attr]
# If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
# and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
# Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
# along with appending information to the `dim_group_names` list whenever necessary.
pg_name = None
for dim_mesh in pg_ranks_by_dim:
subgroup_ranks = dim_mesh.tolist()
dim_group = new_group(
ranks=subgroup_ranks,
timeout=timeout,
backend=backend,
pg_options=pg_options,
group_desc=group_desc,
)
# only add to dim_groups if the current rank in the subgroup
if get_rank() in subgroup_ranks:
if pg_name is not None:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {get_rank()} "
f"in {subgroup_ranks}!"
)
pg_name = dim_group.group_name
return pg_name
@staticmethod
def _init_process_groups(
layout: _MeshLayout,
rank_map: torch.Tensor,
mesh_dim_names: Optional[tuple[str, ...]],
backend_override: tuple[BackendConfig, ...],
) -> list[str]:
# group_name associated with each mesh dimension, each
# mesh dimension should have one sub-group per rank
dim_group_names: list[str] = []
# create sub pgs base on the mesh argument specified
for dim in range(len(layout)):
dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}"
dim_group_names.append(
DeviceMesh._init_one_process_group( # type: ignore[arg-type]
layout[dim], rank_map, dim_name, backend_override[dim]
)
# If bound_device_id exists, it means the nccl communicator has been eagerly initialized
# so that we can use `split_group` to create subgroups through `ncclCommSplit`.
# In this case, we only need to make one API call (`split_group``) for the subgroup creation
# for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per ranks to create
# all the subgroups.
# Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The
# numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
# mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups.
dim_group = None
has_split_group = False
if (
(
bound_device_id := getattr(
default_group, "bound_device_id", None
)
)
is not None
and torch.cuda.is_available()
and (
backend is None
or default_group._get_backend(torch.device("cuda")).name()
== backend
)
):
dim_group = split_group(
parent_pg=default_group,
timeout=timeout,
pg_options=pg_options,
split_ranks=pg_ranks_by_dim.tolist(),
group_desc=group_desc,
)
has_split_group = True
# If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
# and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
# Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
# along with appending information to the `dim_group_names` list whenever necessary.
for dim_mesh in pg_ranks_by_dim:
subgroup_ranks = dim_mesh.tolist()
# We temporarily revert the reuse subgroup, since it breaks two internal tests.
# Temporarily reverting to resolve test timeout while root-causing.
# TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
# pyrefly: ignore [unbound-name]
if bound_device_id is None or not has_split_group:
dim_group = new_group(
ranks=subgroup_ranks,
timeout=timeout,
backend=backend,
pg_options=pg_options,
group_desc=group_desc,
)
# only add to dim_groups if the current rank in the subgroup
if get_rank() in subgroup_ranks:
if len(dim_group_names) > dim:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {get_rank()} "
f"in {subgroup_ranks}!"
)
dim_group_names.append(dim_group.group_name) # type: ignore[union-attr]
)
if any(n is None for n in dim_group_names):
assert all(n is None for n in dim_group_names)
return []
return dim_group_names
def _get_root_mesh(self) -> "DeviceMesh":