Compare commits

...

4 Commits

SHA1 Message Date
61f59966c7 Update on "[Device Mesh] Add an option to decouple PGs when it comes to device mesh save"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-12 16:26:43 -08:00
ec3befc028 [Device Mesh] Add an option to decouple PGs when it comes to device mesh save
[ghstack-poisoned]
2025-11-11 15:47:04 -08:00
306aa9c2a4 Update on "[Device Mesh][ez] Clean up unused parameters and duplicate code"
While refactoring the code, I found that we re-initialize `_flatten_mapping` and still keep `_flatten_mesh_list` in the code, which is no longer needed. Let's remove it.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-11 14:49:07 -08:00
fc1469be71 [Device Mesh][ez] Clean up unused parameters and duplicate code
[ghstack-poisoned]
2025-11-11 14:16:34 -08:00
2 changed files with 123 additions and 10 deletions

View File

@@ -10,9 +10,8 @@ from numpy.testing import assert_array_equal
import torch
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
@@ -554,6 +553,31 @@ class DTensorTest(DTensorTestBase):
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
@with_comms
def test_dtensor_save_load_with_mesh_backend_decouple(self):
import io
# Turn on the gate so that PG names are not saved with the device mesh during torch.save.
DeviceMesh.decouple_backend_at_save = True
device_mesh = self.build_device_mesh()
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
buffer = io.BytesIO()
torch.save(sharded_tensor, buffer)
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=False)
self.assertFalse(hasattr(reloaded_st._spec.mesh, "_dim_group_names"))
reloaded_st._spec.mesh = device_mesh
# We will change this check to expect inequality in a follow-up PR.
self.assertEqual(sharded_tensor, reloaded_st)
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertFalse(hasattr(reloaded_st._spec.mesh, "_dim_group_names"))
reloaded_st._spec.mesh = device_mesh
self.assertEqual(sharded_tensor, reloaded_st)
DeviceMesh.decouple_backend_at_save = False
@skipIfHpu
@with_comms
@unittest.skipIf(
@@ -641,6 +665,7 @@ DTensorTestWithLocalTensor = create_local_tensor_test_class(
# integration
"test_dtensor_save_load",
"test_dtensor_save_load_import",
"test_dtensor_save_load_with_mesh_backend_decouple",
],
)
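
To see how the new gate is meant to be used end to end, here is a minimal sketch that mirrors the test above (it is not part of the diff). It assumes a job where torch.distributed is already initialized on every rank; reaching into `_spec.mesh` follows the test's internal usage and is not a public API.

# Minimal sketch: save a DTensor without PG names, then attach a live mesh after loading.
import io

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Opt in to saving the device mesh without PG names (the gate added in this PR).
DeviceMesh.decouple_backend_at_save = True

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
dtensor = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0)])

buffer = io.BytesIO()
torch.save(dtensor, buffer)
buffer.seek(0)

reloaded = torch.load(buffer, weights_only=True)
# The deserialized mesh carries no PG names, so point the spec at a live mesh explicitly.
reloaded._spec.mesh = mesh

# Restore the default behavior.
DeviceMesh.decouple_backend_at_save = False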

View File

@@ -6,7 +6,7 @@ import threading
import warnings
from collections.abc import Iterator
from itertools import zip_longest
from typing import Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
@@ -173,6 +173,9 @@ else:
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
"""
# Flag to save the device mesh without backend info. This is a temporary variable;
# we will remove it once we fully deprecate saving a device mesh with PG names.
decouple_backend_at_save = False
_device_type: str
_rank_map: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]]
@@ -255,14 +258,13 @@
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._flatten_rank_map = tuple(self._rank_map.tolist())
self._thread_id = None
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
@@ -293,11 +295,6 @@
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
# private field to pre-generate DeviceMesh's hash
self._flatten_rank_map = tuple(self._rank_map.tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property
def device_type(self) -> str:
"""Returns the device type of the mesh."""
@@ -1239,6 +1236,97 @@
res_mesh._dim_group_names = concat_dim_group_name
return res_mesh
def __getstate__(self):
"""
Returns the state of the DeviceMesh as a dictionary for serialization,
which contains all necessary information to reconstruct the DeviceMesh.
"""
state: dict[str, Any] = {
"device_type": self._device_type,
"rank_map": self._rank_map,
"layout": self._layout,
"mesh_dim_names": self._mesh_dim_names,
"thread_id": self._thread_id,
"coordinate_on_dim": getattr(self, "_coordinate_on_dim", None),
}
# Serialize root_mesh if it exists
# To avoid infinite recursion (root -> child -> root), only serialize if this is not the root
if self._root_mesh is not None:
state["root_mesh"] = self._root_mesh.__getstate__()
else:
state["root_mesh"] = None
# Serialize flatten_mapping
flatten_mapping: dict[str, Any] = {}
for mesh_name, mesh in self._flatten_mapping.items():
flatten_mapping[mesh_name] = mesh.__getstate__()
state["flatten_mapping"] = flatten_mapping
if not self.decouple_backend_at_save and hasattr(self, "_dim_group_names"):
logger.warning(
"Save device mesh via torch.save with pg names and will be deprecated in PT 2.11. "
"Users are welcome to use Distributed checkpoint (DCP) or re-create pgs in the same order"
"as the original device mesh."
)
state["dim_group_names"] = self._dim_group_names
return state
def __setstate__(self, state):
"""
Restores the DeviceMesh state from a state dictionary.
"""
required_keys = {
"device_type",
"rank_map",
"layout",
"mesh_dim_names",
"thread_id",
"coordinate_on_dim",
"root_mesh",
"flatten_mapping",
}
missing_keys = required_keys - state.keys()
if missing_keys:
raise ValueError(f"state_dict is missing required keys: {missing_keys}")
# Restore basic attributes
self._device_type = state["device_type"]
self._rank_map = state["rank_map"]
self._layout = state["layout"]
self._mesh_dim_names = state["mesh_dim_names"]
self._thread_id = state["thread_id"]
if state.get("coordinate_on_dim") is not None:
self._coordinate_on_dim = state["coordinate_on_dim"]
# Restore root_mesh if it exists
if state.get("root_mesh") is not None:
# Create a new DeviceMesh for the root mesh
root_mesh = DeviceMesh.__new__(DeviceMesh)
root_mesh.__setstate__(state["root_mesh"])
self._root_mesh = root_mesh
else:
self._root_mesh = None
# Re-initialize internal bookkeeping
self._flatten_rank_map = tuple(self._rank_map.tolist())
# Restore flatten_mapping
self._flatten_mapping = {}
if state.get("flatten_mapping"):
for mesh_name, mesh_state in state["flatten_mapping"].items():
flatten_mesh = DeviceMesh.__new__(DeviceMesh)
flatten_mesh.__setstate__(mesh_state)
self._flatten_mapping[mesh_name] = flatten_mesh
# We don't recommend loading from saved PG names, because users must create process groups
# in the same order as when the device mesh was saved. This is implicit and error-prone,
# and we will remove this behavior soon. Instead, we recommend explicitly creating PGs
# and setting them on the loaded mesh.
if state.get("dim_group_names"):
self._dim_group_names = state["dim_group_names"]
def _normalize_backend_override(
backend_override: dict[
Union[int, str],
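
Taken together, `__getstate__` and `__setstate__` make a DeviceMesh picklable on its own. The sketch below (not part of the diff) illustrates the flow recommended in the `__setstate__` comments: serialize the mesh without PG names and, after loading, create process groups explicitly by building a live mesh in the new job instead of relying on saved `dim_group_names`. It assumes a 4-rank job with torch.distributed already initialized.

# Sketch only: round-trip a DeviceMesh through pickle without PG names.
import pickle

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

DeviceMesh.decouple_backend_at_save = True
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

blob = pickle.dumps(mesh)      # __getstate__ skips _dim_group_names under the gate
restored = pickle.loads(blob)  # __setstate__ rebuilds the rank map, layout, and flatten mapping
assert not hasattr(restored, "_dim_group_names")

# Recommended: re-create PGs explicitly by constructing a live mesh with the same
# shape and dim names in the new job, and use it wherever the restored mesh is consumed.
live_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
DeviceMesh.decouple_backend_at_save = False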