[DeviceMesh] Add a warning for slicing flattened dim from root mesh and types for _get_slice_mesh_layout (#164993)

As the title says, this adds a deprecation warning for slicing a flattened dim from the root mesh. It also makes a cosmetic change: adding type annotations to `_get_slice_mesh_layout`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164993
Approved by: https://github.com/fegin
ghstack dependencies: #164750, #164954
This commit is contained in:
fduwjj
2025-10-08 14:56:54 -07:00
committed by PyTorch MergeBot
parent 90b4e130d6
commit 7a1ead755f

View File

@ -239,7 +239,9 @@ else:
) )
return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name))
def _get_slice_mesh_layout(self, device_mesh, mesh_dim_names) -> _MeshLayout: def _get_slice_mesh_layout(
self, device_mesh: "DeviceMesh", mesh_dim_names: tuple[str, ...]
) -> _MeshLayout:
""" """
Validate whether the mesh_dim_names is valid for slicing the given device_mesh. Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
If valid, return dim indexes of the slice mesh in the device mesh. If valid, return dim indexes of the slice mesh in the device mesh.
@ -266,7 +268,7 @@ else:
else {} else {}
) )
valid_mesh_dim_names = [ valid_mesh_dim_names = [
*device_mesh.mesh_dim_names, *not_none(device_mesh.mesh_dim_names),
*flatten_name_to_root_layout, *flatten_name_to_root_layout,
] ]
@ -281,11 +283,17 @@ else:
layout_sliced = [] layout_sliced = []
for name in mesh_dim_names: for name in mesh_dim_names:
if name in device_mesh.mesh_dim_names: if name in not_none(device_mesh.mesh_dim_names):
layout_sliced.append( layout_sliced.append(
device_mesh._layout[device_mesh.mesh_dim_names.index(name)] device_mesh._layout[
not_none(device_mesh.mesh_dim_names).index(name)
]
) )
elif name in flatten_name_to_root_layout: elif name in flatten_name_to_root_layout:
warnings.warn(
"Slicing a flattened dim from root mesh will be deprecated in PT 2.11. "
"Users need to bookkeep the flattened mesh directly. "
)
layout_sliced.append(flatten_name_to_root_layout[name]) layout_sliced.append(flatten_name_to_root_layout[name])
sliced_sizes = tuple(l.sizes for l in layout_sliced) sliced_sizes = tuple(l.sizes for l in layout_sliced)