From 7a1ead755f2e2abe8be49a7a0fb88b6b13973147 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Wed, 8 Oct 2025 14:56:54 -0700 Subject: [PATCH] [DeviceMesh] Add a warning for slicing flattened dim from root mesh and types for _get_slice_mesh_layout (#164993) As title, we want to add a deprecate warning for slicing flattened dim from root mesh. Also cosmetic changes for adding types for `_get_slice_mesh_layout`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164993 Approved by: https://github.com/fegin ghstack dependencies: #164750, #164954 --- torch/distributed/device_mesh.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index c90dba2220c5..9fef00f5a809 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -239,7 +239,9 @@ else: ) return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) - def _get_slice_mesh_layout(self, device_mesh, mesh_dim_names) -> _MeshLayout: + def _get_slice_mesh_layout( + self, device_mesh: "DeviceMesh", mesh_dim_names: tuple[str, ...] + ) -> _MeshLayout: """ Validate whether the mesh_dim_names is valid for slicing the given device_mesh. If valid, return dim indexes of the slice mesh in the device mesh. @@ -266,7 +268,7 @@ else: else {} ) valid_mesh_dim_names = [ - *device_mesh.mesh_dim_names, + *not_none(device_mesh.mesh_dim_names), *flatten_name_to_root_layout, ] @@ -281,11 +283,17 @@ else: layout_sliced = [] for name in mesh_dim_names: - if name in device_mesh.mesh_dim_names: + if name in not_none(device_mesh.mesh_dim_names): layout_sliced.append( - device_mesh._layout[device_mesh.mesh_dim_names.index(name)] + device_mesh._layout[ + not_none(device_mesh.mesh_dim_names).index(name) + ] ) elif name in flatten_name_to_root_layout: + warnings.warn( + "Slicing a flattened dim from root mesh will be deprecated in PT 2.11. " + "Users need to bookkeep the flattened mesh directly. " + ) layout_sliced.append(flatten_name_to_root_layout[name]) sliced_sizes = tuple(l.sizes for l in layout_sliced)