mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[DeviceMesh] Add a warning for slicing flattened dim from root mesh and types for _get_slice_mesh_layout (#164993)
As title, we want to add a deprecate warning for slicing flattened dim from root mesh. Also cosmetic changes for adding types for `_get_slice_mesh_layout`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164993 Approved by: https://github.com/fegin ghstack dependencies: #164750, #164954
This commit is contained in:
@ -239,7 +239,9 @@ else:
|
|||||||
)
|
)
|
||||||
return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name))
|
return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name))
|
||||||
|
|
||||||
def _get_slice_mesh_layout(self, device_mesh, mesh_dim_names) -> _MeshLayout:
|
def _get_slice_mesh_layout(
|
||||||
|
self, device_mesh: "DeviceMesh", mesh_dim_names: tuple[str, ...]
|
||||||
|
) -> _MeshLayout:
|
||||||
"""
|
"""
|
||||||
Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
|
Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
|
||||||
If valid, return dim indexes of the slice mesh in the device mesh.
|
If valid, return dim indexes of the slice mesh in the device mesh.
|
||||||
@ -266,7 +268,7 @@ else:
|
|||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
valid_mesh_dim_names = [
|
valid_mesh_dim_names = [
|
||||||
*device_mesh.mesh_dim_names,
|
*not_none(device_mesh.mesh_dim_names),
|
||||||
*flatten_name_to_root_layout,
|
*flatten_name_to_root_layout,
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -281,11 +283,17 @@ else:
|
|||||||
|
|
||||||
layout_sliced = []
|
layout_sliced = []
|
||||||
for name in mesh_dim_names:
|
for name in mesh_dim_names:
|
||||||
if name in device_mesh.mesh_dim_names:
|
if name in not_none(device_mesh.mesh_dim_names):
|
||||||
layout_sliced.append(
|
layout_sliced.append(
|
||||||
device_mesh._layout[device_mesh.mesh_dim_names.index(name)]
|
device_mesh._layout[
|
||||||
|
not_none(device_mesh.mesh_dim_names).index(name)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
elif name in flatten_name_to_root_layout:
|
elif name in flatten_name_to_root_layout:
|
||||||
|
warnings.warn(
|
||||||
|
"Slicing a flattened dim from root mesh will be deprecated in PT 2.11. "
|
||||||
|
"Users need to bookkeep the flattened mesh directly. "
|
||||||
|
)
|
||||||
layout_sliced.append(flatten_name_to_root_layout[name])
|
layout_sliced.append(flatten_name_to_root_layout[name])
|
||||||
|
|
||||||
sliced_sizes = tuple(l.sizes for l in layout_sliced)
|
sliced_sizes = tuple(l.sizes for l in layout_sliced)
|
||||||
|
Reference in New Issue
Block a user