mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)
We want to refactor the internal bookkeeping of DeviceMesh so that: Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible. Concretely, in this PR, we do the following: 1. Use the `_MeshLayout` to handle all index operations rather use a map to record mesh dims. 2. Removed `flatten_name_to_root_dims`, because now we can directly get layout from a flattened device mesh. 3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`. 4. Use the newly added function `check_overlap` to check layout overlap. 5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The reason is that layout acts as a backend of mesh tensor bookkeeping (indexing indices), it needs to be used as indices for remap back to the mesh tensor for new DeviceMesh generation and backend init. For example, in the case of 2K to 4K, the underlying layout is (2K, 1) but the actual value of the mesh tensor is [2K, 2K+1, ....,]. While flattening, slicing, we need to remap the layout back to the new mesh tensor so it maps the actual device allocation. For example, in the 2K to 4K case, if the shape is (1K, 1K) with dim_names ("dp", "tp"). Then when slicing "tp", the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ... 4K-1). not the global ranks generated from the layout. (1K, 1). Verified that loss curve is very close for DeepSeekV3 on torchtitan, note that exact same match is challenging because even if we run the baseline twice, the loss curve does not exactly match. <img width="1113" height="490" alt="image" src="https://github.com/user-attachments/assets/7877b5a4-337e-4ad8-b878-2378f4f0f38d" /> The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor. With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout. This is a continue of https://github.com/pytorch/pytorch/pull/161106 (original one got messed with EasyCLA) Pull Request resolved: https://github.com/pytorch/pytorch/pull/163213 Approved by: https://github.com/lw, https://github.com/fegin
This commit is contained in:
@ -6,12 +6,13 @@ import os
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Iterator
|
||||
from functools import reduce
|
||||
from itertools import zip_longest
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import is_available
|
||||
from torch.distributed._mesh_layout import _MeshLayout
|
||||
from torch.distributed._pycute import is_int
|
||||
from torch.utils._typing_utils import not_none
|
||||
|
||||
|
||||
@ -66,17 +67,16 @@ else:
|
||||
)
|
||||
|
||||
BackendConfig = tuple[Optional[str], Optional[C10dBackend.Options]]
|
||||
torch.serialization.add_safe_globals([_MeshLayout])
|
||||
|
||||
class _MeshEnv(threading.local):
|
||||
def __init__(self) -> None:
|
||||
self.mesh_stack: list[DeviceMesh] = []
|
||||
# TODO: Move the bookkeeping maps from _MeshEnv to DeviceMesh.
|
||||
self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {}
|
||||
self.mesh_dim_group_options: dict[int, BackendConfig] = {}
|
||||
# Record flatten mesh name to its flattened mesh in root mesh.
|
||||
self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {}
|
||||
# Record flatten mesh name to its mesh dim index in root mesh.
|
||||
self.flatten_name_to_root_dims: dict[
|
||||
DeviceMesh, dict[str, tuple[int, ...]]
|
||||
] = {}
|
||||
|
||||
def get_current_mesh(self) -> "DeviceMesh":
|
||||
if len(self.mesh_stack) == 0:
|
||||
@ -86,108 +86,51 @@ else:
|
||||
def create_sub_mesh(
|
||||
self,
|
||||
device_mesh: "DeviceMesh",
|
||||
layout: _MeshLayout,
|
||||
submesh_dim_names: tuple[str, ...],
|
||||
submesh_dims: list[tuple[int, ...]],
|
||||
) -> "DeviceMesh":
|
||||
# Get the submesh dim size from the submesh_dims.
|
||||
# For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want
|
||||
# to slice out mesh["dp_cp"], then submesh_dims = [(0, 1), (2,)] and submesh_dim_size = [2 * 2, 2] = [4, 2].
|
||||
# If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and submesh_dim_size = [2, 2].
|
||||
slice_dim_size = [
|
||||
reduce(
|
||||
lambda x, y: x * device_mesh.mesh.size(y),
|
||||
mesh_dim,
|
||||
1,
|
||||
)
|
||||
for mesh_dim in submesh_dims
|
||||
]
|
||||
|
||||
mesh_tensor = device_mesh.mesh
|
||||
# slice_dim_idx could be different from submesh_dims, as we may need to flatten out some dims.
|
||||
slice_dim_idx = []
|
||||
root_mesh = self.get_root_mesh(device_mesh)
|
||||
slice_dim_group_name = []
|
||||
# keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
|
||||
# flattened mesh tensor.
|
||||
num_dims_flatten = 0
|
||||
for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names):
|
||||
# Currently, this only allows slicing out a contiguous flattened dim.
|
||||
# TODO: we need to handle reconstructing a non-contiguous flattened dim.
|
||||
if len(mesh_dim_indices) > 1:
|
||||
# We need to move the start_dim and end_dim to the left if some dims are already flattened.
|
||||
mesh_tensor = mesh_tensor.flatten(
|
||||
start_dim=mesh_dim_indices[0] - num_dims_flatten,
|
||||
end_dim=mesh_dim_indices[-1] - num_dims_flatten,
|
||||
)
|
||||
# If some dims are already flattened, we need to adjust the slice_dim_idx accordingly.
|
||||
# For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened,
|
||||
# then the final slice_dim_idx should be [0, 1, 2].
|
||||
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
|
||||
num_dims_flatten += len(mesh_dim_indices) - 1
|
||||
for name in submesh_dim_names:
|
||||
if name in not_none(device_mesh.mesh_dim_names):
|
||||
slice_dim_group_name.append(
|
||||
self.root_to_flatten_mapping[device_mesh][
|
||||
mesh_dim_name
|
||||
]._dim_group_names[0] # type: ignore[has-type]
|
||||
device_mesh._dim_group_names[ # type: ignore[has-type]
|
||||
not_none(device_mesh.mesh_dim_names).index(name)
|
||||
]
|
||||
)
|
||||
else:
|
||||
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
|
||||
# If device_mesh is not root_mesh, we already throw error in _get_slice_mesh_layout
|
||||
# Since we will deprecate the slicing of flattened dim_name from root mesh soon,
|
||||
# we don't want to optimize the code furthermore.
|
||||
flatten_mesh = self.root_to_flatten_mapping[device_mesh][name]
|
||||
slice_dim_group_name.append(
|
||||
device_mesh._dim_group_names[mesh_dim_indices[0]] # type: ignore[has-type]
|
||||
flatten_mesh._dim_group_names[ # type: ignore[has-type]
|
||||
not_none(flatten_mesh.mesh_dim_names).index(name)
|
||||
]
|
||||
)
|
||||
|
||||
# mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
|
||||
mesh_dims_remained_idx = list(range(mesh_tensor.ndim))
|
||||
for idx in slice_dim_idx:
|
||||
if idx not in mesh_dims_remained_idx:
|
||||
raise NotImplementedError(
|
||||
"Currently, this only allows slicing out a contiguous flattened dim."
|
||||
)
|
||||
mesh_dims_remained_idx.remove(idx)
|
||||
|
||||
# pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx]
|
||||
# This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with
|
||||
# the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank.
|
||||
pg_ranks_by_dim = mesh_tensor.permute(
|
||||
*mesh_dims_remained_idx, *slice_dim_idx
|
||||
).reshape(-1, *slice_dim_size)
|
||||
|
||||
cur_rank = device_mesh.get_rank()
|
||||
for mesh_nd in pg_ranks_by_dim:
|
||||
submesh = DeviceMesh(
|
||||
device_mesh.device_type,
|
||||
mesh_nd,
|
||||
mesh_dim_names=submesh_dim_names,
|
||||
_init_backend=False,
|
||||
)
|
||||
if cur_rank in mesh_nd:
|
||||
res_submesh = submesh
|
||||
pg_ranks_by_dim = layout.remap_to_tensor(
|
||||
root_mesh.mesh,
|
||||
)
|
||||
res_submesh = DeviceMesh._create_mesh_from_ranks(
|
||||
device_mesh.device_type,
|
||||
pg_ranks_by_dim,
|
||||
cur_rank,
|
||||
submesh_dim_names,
|
||||
_init_backend=False,
|
||||
_layout=layout,
|
||||
)
|
||||
|
||||
res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined, has-type]
|
||||
self.child_to_root_mapping[res_submesh] = device_mesh
|
||||
|
||||
res_submesh._dim_group_names = slice_dim_group_name
|
||||
self.child_to_root_mapping[res_submesh] = root_mesh
|
||||
return res_submesh
|
||||
|
||||
def create_flatten_mesh(
|
||||
self,
|
||||
device_mesh: "DeviceMesh",
|
||||
mesh_dim_name: Optional[str] = None,
|
||||
backend_override: BackendConfig = (
|
||||
None,
|
||||
None,
|
||||
),
|
||||
backend_override: BackendConfig = (None, None),
|
||||
) -> "DeviceMesh":
|
||||
root_mesh = _mesh_resources.get_root_mesh(device_mesh)
|
||||
|
||||
flatten_dims_in_root = [
|
||||
not_none(root_mesh.mesh_dim_names).index(flatten_mesh_dim_name)
|
||||
for flatten_mesh_dim_name in not_none(device_mesh.mesh_dim_names)
|
||||
]
|
||||
root_mesh = self.get_root_mesh(device_mesh)
|
||||
|
||||
if not mesh_dim_name:
|
||||
mesh_dim_name = "_".join(not_none(device_mesh.mesh_dim_names))
|
||||
@ -199,7 +142,6 @@ else:
|
||||
return device_mesh
|
||||
|
||||
# Check whether the mesh_dim_name for flattened mesh is valid.
|
||||
self.flatten_name_to_root_dims.setdefault(root_mesh, {})
|
||||
invalid_dim_names = not_none(root_mesh.mesh_dim_names)
|
||||
if mesh_dim_name in invalid_dim_names:
|
||||
raise ValueError(
|
||||
@ -208,47 +150,43 @@ else:
|
||||
f"Please specify another valid mesh_dim_name.",
|
||||
)
|
||||
|
||||
flattened_mesh_layout = device_mesh._layout.coalesce()
|
||||
# Quick return if the flatten mesh has been created before.
|
||||
if (
|
||||
root_mesh in self.root_to_flatten_mapping
|
||||
and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]
|
||||
):
|
||||
if (
|
||||
tuple(flatten_dims_in_root)
|
||||
== self.flatten_name_to_root_dims[root_mesh][mesh_dim_name]
|
||||
flattened_mesh_layout
|
||||
== self.root_to_flatten_mapping[root_mesh][mesh_dim_name]._layout
|
||||
):
|
||||
return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Flatten mesh with mesh_dim_name {mesh_dim_name} has been created before, "
|
||||
f"Please specify another valid mesh_dim_name.",
|
||||
f"Please specify another valid mesh_dim_name."
|
||||
)
|
||||
|
||||
flattened_mesh_dim_size = math.prod(device_mesh.mesh.size())
|
||||
|
||||
remained_dims_in_root = list(range(root_mesh.mesh.ndim))
|
||||
for flatten_dim_in_root in flatten_dims_in_root:
|
||||
remained_dims_in_root.remove(flatten_dim_in_root)
|
||||
|
||||
pg_ranks_by_dim = root_mesh.mesh.permute(
|
||||
*remained_dims_in_root, *flatten_dims_in_root
|
||||
).reshape(-1, flattened_mesh_dim_size)
|
||||
|
||||
cur_rank = root_mesh.get_rank()
|
||||
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
|
||||
# new_group api to avoid potential hang.
|
||||
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
|
||||
root_mesh.mesh,
|
||||
)
|
||||
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
|
||||
root_mesh.device_type,
|
||||
pg_ranks_by_dim,
|
||||
pg_ranks_by_dim.flatten(
|
||||
start_dim=1
|
||||
), # this is needed for flatten non-contiguous mesh dims.
|
||||
cur_rank,
|
||||
(mesh_dim_name,),
|
||||
(backend_override,),
|
||||
_layout=device_mesh._layout.coalesce(),
|
||||
)
|
||||
self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined]
|
||||
self.child_to_root_mapping[res_flattened_mesh] = root_mesh
|
||||
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
|
||||
res_flattened_mesh # type: ignore[possibly-undefined]
|
||||
res_flattened_mesh
|
||||
)
|
||||
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(
|
||||
flatten_dims_in_root
|
||||
) # type: ignore[possibly-undefined]
|
||||
|
||||
return res_flattened_mesh
|
||||
|
||||
@ -310,27 +248,35 @@ else:
|
||||
) -> None:
|
||||
self.mesh_dim_group_options[dim] = (backend, pg_options)
|
||||
|
||||
def _get_slice_mesh_dims(
|
||||
self, device_mesh, mesh_dim_names
|
||||
) -> list[tuple[int, ...]]:
|
||||
def _get_slice_mesh_layout(self, device_mesh, mesh_dim_names) -> _MeshLayout:
|
||||
"""
|
||||
Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
|
||||
If valid, return dim indexes of the slice mesh in the device mesh.
|
||||
"""
|
||||
slice_from_root = True
|
||||
if device_mesh != self.get_root_mesh(device_mesh):
|
||||
warnings.warn(
|
||||
"You are attempting to slice a submesh from another submesh. While we support this operation, "
|
||||
"it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. "
|
||||
"If not, this may result in some ranks receiving the submesh while others encounter errors."
|
||||
)
|
||||
slice_from_root = False
|
||||
|
||||
# The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names
|
||||
# or its flattened mesh's mesh_dim_names.
|
||||
self.flatten_name_to_root_dims.setdefault(device_mesh, {})
|
||||
flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh]
|
||||
# or its flattened mesh's mesh_dim_names if it's root_mesh.
|
||||
flatten_name_to_root_layout = (
|
||||
{
|
||||
key: mesh._layout
|
||||
for key, mesh in self.root_to_flatten_mapping.setdefault(
|
||||
device_mesh, {}
|
||||
).items()
|
||||
}
|
||||
if slice_from_root
|
||||
else {}
|
||||
)
|
||||
valid_mesh_dim_names = [
|
||||
*device_mesh.mesh_dim_names,
|
||||
*flatten_name_to_root_dims,
|
||||
*flatten_name_to_root_layout,
|
||||
]
|
||||
|
||||
if not all(
|
||||
@ -342,30 +288,51 @@ else:
|
||||
f"Valid mesh_dim_names are {valid_mesh_dim_names}."
|
||||
)
|
||||
|
||||
layout_sliced = []
|
||||
for name in mesh_dim_names:
|
||||
if name in device_mesh.mesh_dim_names:
|
||||
layout_sliced.append(
|
||||
device_mesh._layout[device_mesh.mesh_dim_names.index(name)]
|
||||
)
|
||||
elif name in flatten_name_to_root_layout:
|
||||
layout_sliced.append(flatten_name_to_root_layout[name])
|
||||
|
||||
sliced_sizes = tuple(l.sizes for l in layout_sliced)
|
||||
sliced_strides = tuple(l.strides for l in layout_sliced)
|
||||
|
||||
# The check below is from DeviceMesh's implementation before adopting CuTe layout for internal
|
||||
# bookkeeping and it can be removed but we need to define what is the expected behavior.
|
||||
# TODO: Remove the below check and define the expected behavior.
|
||||
# Validate the order of the slice mesh dim indices.
|
||||
# This needs to be in ascending order.
|
||||
curr_idx = -1
|
||||
slice_mesh_dims = []
|
||||
for mesh_dim_name in mesh_dim_names:
|
||||
if mesh_dim_name in flatten_name_to_root_dims:
|
||||
mesh_indices = flatten_name_to_root_dims[mesh_dim_name]
|
||||
# TODO: this doesn't allow non-contiguous slicing with flatten dim yet. next_idx
|
||||
# should be mesh_indices[0] once we support non-contiguous slicing with flatten dim.
|
||||
next_idx = mesh_indices[-1]
|
||||
slice_mesh_dims.append(mesh_indices)
|
||||
else:
|
||||
next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name)
|
||||
slice_mesh_dims.append((next_idx,))
|
||||
if next_idx <= curr_idx:
|
||||
pre_stride = -1
|
||||
for stride in reversed(sliced_strides):
|
||||
# Note that with CuTe layout, we can support slicing flattened non-contiguous mesh dims with no problem.
|
||||
# But this will make this behavior complicated so we decided to not support it for now.
|
||||
if not is_int(stride):
|
||||
raise NotImplementedError(
|
||||
"Currently, this only allows slicing out a contiguous flattened dim."
|
||||
)
|
||||
if stride < pre_stride:
|
||||
raise KeyError(
|
||||
f"Invalid mesh_dim_names {mesh_dim_names} specified. "
|
||||
f"Found mesh dim indices to slice: {slice_mesh_dims}. "
|
||||
"Mesh dim indices should be in ascending order."
|
||||
)
|
||||
curr_idx = next_idx
|
||||
pre_stride = stride
|
||||
|
||||
return slice_mesh_dims
|
||||
# When users sliced dim_names outside from current mesh, we will check whether
|
||||
# there is layout overlap.
|
||||
# TODO: Eventually we will just directly throw error here because
|
||||
# we will deprecate the slicing of flattened dim_name from root mesh.
|
||||
layout_sliced = _MeshLayout(sliced_sizes, sliced_strides)
|
||||
if not layout_sliced.check_non_overlap():
|
||||
raise RuntimeError(
|
||||
f"Slicing overlapping dim_names {mesh_dim_names} is not allowed."
|
||||
)
|
||||
|
||||
return layout_sliced
|
||||
|
||||
# TODO: to make this use case by other components public API in the future.
|
||||
def _get_all_submeshes(
|
||||
self, device_mesh: "DeviceMesh", mesh_dim_name: str
|
||||
) -> list["DeviceMesh"]:
|
||||
@ -373,10 +340,10 @@ else:
|
||||
Return all the submeshes of a given mesh dimension of the device mesh.
|
||||
"""
|
||||
mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name)
|
||||
pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
|
||||
-1, device_mesh.mesh.size(mesh_dim)
|
||||
layout = device_mesh._layout[mesh_dim]
|
||||
pg_ranks_by_dim = layout.remap_to_tensor(
|
||||
device_mesh.mesh,
|
||||
)
|
||||
|
||||
cur_rank = device_mesh.get_rank()
|
||||
res_submeshes = []
|
||||
for mesh_1d in pg_ranks_by_dim:
|
||||
@ -386,7 +353,7 @@ else:
|
||||
mesh_dim_names=(mesh_dim_name,),
|
||||
_init_backend=False,
|
||||
)
|
||||
submesh._dim_group_names = (
|
||||
submesh._dim_group_names = ( # type: ignore[has-type]
|
||||
[device_mesh._dim_group_names[mesh_dim]] # type: ignore[has-type]
|
||||
if cur_rank in mesh_1d
|
||||
else []
|
||||
@ -454,9 +421,11 @@ else:
|
||||
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
|
||||
"""
|
||||
|
||||
# TODO: to make existing public fields private and add some methods/properties for bc.
|
||||
device_type: str
|
||||
mesh: torch.Tensor
|
||||
mesh_dim_names: Optional[tuple[str, ...]]
|
||||
_layout: _MeshLayout
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -467,18 +436,33 @@ else:
|
||||
backend_override: Optional[tuple[BackendConfig, ...]] = None,
|
||||
_init_backend: bool = True,
|
||||
_rank: Optional[int] = None,
|
||||
_layout: Optional[_MeshLayout] = None,
|
||||
) -> None:
|
||||
self.device_type = device_type
|
||||
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
|
||||
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
|
||||
self.mesh = (
|
||||
mesh.detach().to(dtype=torch.int)
|
||||
mesh.detach().to(dtype=torch.int).contiguous()
|
||||
if isinstance(mesh, torch.Tensor)
|
||||
else torch.tensor(mesh, device="cpu", dtype=torch.int)
|
||||
)
|
||||
self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
|
||||
if backend_override is None:
|
||||
backend_override = ((None, None),) * self.mesh.ndim
|
||||
# Internal bookkeeping for the device mesh.
|
||||
self._layout = (
|
||||
_layout
|
||||
if _layout
|
||||
else _MeshLayout(self.mesh.size(), self.mesh.stride())
|
||||
)
|
||||
assert self._layout.check_non_overlap(), (
|
||||
"Please use a non-overlapping layout when creating a DeviceMesh."
|
||||
)
|
||||
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
|
||||
assert self._layout.numel() == self.mesh.numel(), (
|
||||
"Please use a valid layout when creating a DeviceMesh."
|
||||
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
|
||||
)
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
|
||||
@ -788,7 +772,7 @@ else:
|
||||
if mesh_dim_names == self.mesh_dim_names:
|
||||
return self
|
||||
else:
|
||||
slice_mesh_dims = _mesh_resources._get_slice_mesh_dims(
|
||||
sliced_mesh_layout = _mesh_resources._get_slice_mesh_layout(
|
||||
self, mesh_dim_names
|
||||
)
|
||||
# When using FakeTensorMode to trace the model, `create_sub_mesh()` will
|
||||
@ -802,7 +786,7 @@ else:
|
||||
# TODO: compiler + device_mesh slicing.
|
||||
with torch._subclasses.fake_tensor.unset_fake_temporarily():
|
||||
submesh = _mesh_resources.create_sub_mesh(
|
||||
self, mesh_dim_names, slice_mesh_dims
|
||||
self, sliced_mesh_layout, mesh_dim_names
|
||||
)
|
||||
return submesh
|
||||
|
||||
@ -868,6 +852,7 @@ else:
|
||||
mesh_dim_names: tuple[str, ...],
|
||||
backend_override: Optional[tuple[BackendConfig, ...]] = None,
|
||||
_init_backend: bool = True,
|
||||
_layout: Optional[_MeshLayout] = None,
|
||||
) -> "DeviceMesh":
|
||||
"""
|
||||
Helper method to create a DeviceMesh from tensor `pg_ranks_by_dim`. This is due to
|
||||
@ -899,6 +884,7 @@ else:
|
||||
mesh_dim_names=mesh_dim_names,
|
||||
backend_override=backend_override,
|
||||
_init_backend=_init_backend,
|
||||
_layout=_layout,
|
||||
)
|
||||
if cur_rank in mesh_nd:
|
||||
res_mesh = mesh
|
||||
@ -984,6 +970,10 @@ else:
|
||||
raise ValueError(
|
||||
"Must pass mesh_dim_names if passing multiple ProcessGroups"
|
||||
)
|
||||
# When init a DeviceMesh with multiple ProcessGroups directly, we need to make sure
|
||||
# the mesh tensor is contiguous. Otherwise, the layout we inferred from the mesh tensor
|
||||
# will have larger span than the actual tensor. This is just internal implementation detail
|
||||
# and does not affect user facing behavior.
|
||||
mesh = (
|
||||
mesh.detach().to(dtype=torch.int, device="cpu")
|
||||
if isinstance(mesh, torch.Tensor)
|
||||
|
Reference in New Issue
Block a user