Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-15 23:04:54 +08:00
Compare commits
78 Commits
tianren/sy
...
ciflow/tru
| SHA1 | Author | Date |
|---|---|---|
| 54abae7c07 | |||
| 6a5dc667a4 | |||
| 6aca75eab9 | |||
| f8d0bef572 | |||
| 23099ad498 | |||
| 1842dde349 | |||
| 63375b0adb | |||
| 2cc9ff2f6a | |||
| 75eb58ee1d | |||
| 6d41a72865 | |||
| e4ad59038a | |||
| 150eac735f | |||
| 8d5f429f3b | |||
| 2ea430ce25 | |||
| e186bca40f | |||
| 602f1f42c2 | |||
| 24e0f27711 | |||
| b9a6e020a2 | |||
| 310db76d36 | |||
| 4a70f2d033 | |||
| 82fc673005 | |||
| e587947b20 | |||
| ff0ebf7fe5 | |||
| 7a85a3d289 | |||
| baf793c54d | |||
| e0c47235a3 | |||
| f6b4fe1c64 | |||
| 092d3778a3 | |||
| 82b416f4db | |||
| 36233be9d3 | |||
| f56e3c8fdf | |||
| ad324d0e3c | |||
| a4dc3dafee | |||
| 766493267e | |||
| bd086ac1b3 | |||
| 411a5c7f7f | |||
| 3bab82c453 | |||
| 11a624dc28 | |||
| 4736bb57e3 | |||
| c2aaa5664c | |||
| 569a9000a5 | |||
| 73f13abc50 | |||
| 3038d7f285 | |||
| 57448253f3 | |||
| 32caf41d72 | |||
| 28185f4406 | |||
| 92c709c202 | |||
| f094af1e1a | |||
| 09db0ef757 | |||
| ecd3f525d5 | |||
| b2345c972f | |||
| fe44a87ed4 | |||
| f0aa9cfc42 | |||
| 0ea286d26d | |||
| 621b9f2be8 | |||
| 5ac6d410aa | |||
| 2bb8d19968 | |||
| 2cd038d95d | |||
| 5decc0e164 | |||
| bafbc39603 | |||
| 356abd0719 | |||
| 457bdbfaa4 | |||
| 5196ba3db4 | |||
| bf108cf3c9 | |||
| 389519a03c | |||
| 873ec8442e | |||
| 5841ede067 | |||
| 4c90367bfb | |||
| 27910fb22a | |||
| e5cbce1780 | |||
| 4dd9c8cf2d | |||
| 8c0643a671 | |||
| 0f7736dd55 | |||
| 9e8f81aaa8 | |||
| 72706c7cb9 | |||
| 01a3a8550a | |||
| 00e03cccf3 | |||
| 1a58e8dad5 |
@@ -10,9 +10,8 @@ from numpy.testing import assert_array_equal
import torch
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    Partial,
@@ -554,6 +553,40 @@ class DTensorTest(DTensorTestBase):
        reloaded_st = torch.load(buffer, weights_only=True)
        self.assertEqual(sharded_tensor, reloaded_st)

    @with_comms
    def test_dtensor_save_load_with_mesh_backend_decouple(self):
        import io

        # Turn on gate for not saving PG names for device mesh when it comes to torch.save.
        DeviceMesh.decouple_backend_at_save = True
        device_mesh = self.build_device_mesh()
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        buffer = io.BytesIO()
        torch.save(sharded_tensor, buffer)
        buffer.seek(0)
        reloaded_st = torch.load(buffer, weights_only=False)
        self.assertFalse(hasattr(reloaded_st._spec.mesh, "_dim_group_names"))
        self.assertNotEqual(sharded_tensor._spec.mesh, reloaded_st._spec.mesh)
        self.assertEqual(
            sharded_tensor.to_local().tolist(), reloaded_st.to_local().tolist()
        )
        self.assertEqual(sharded_tensor._spec.placements, reloaded_st._spec.placements)
        reloaded_st._spec.mesh = device_mesh
        self.assertEqual(sharded_tensor, reloaded_st)
        buffer.seek(0)
        reloaded_st = torch.load(buffer, weights_only=True)
        self.assertFalse(hasattr(reloaded_st._spec.mesh, "_dim_group_names"))
        self.assertNotEqual(sharded_tensor._spec.mesh, reloaded_st._spec.mesh)
        self.assertEqual(
            sharded_tensor.to_local().tolist(), reloaded_st.to_local().tolist()
        )
        self.assertEqual(sharded_tensor._spec.placements, reloaded_st._spec.placements)
        reloaded_st._spec.mesh = device_mesh
        self.assertEqual(sharded_tensor, reloaded_st)
        DeviceMesh.decouple_backend_at_save = False

    @skipIfHpu
    @with_comms
    @unittest.skipIf(
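A minimal sketch of the save/load flow this test exercises, assuming the `DeviceMesh.decouple_backend_at_save` flag and the private `_spec.mesh` reattachment shown in this diff; it requires an already-initialized process group (e.g. launched via torchrun) and is illustrative rather than the canonical API:

```python
import io

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Assumes torch.distributed is already initialized (e.g. via torchrun).
DeviceMesh.decouple_backend_at_save = True  # drop PG names at save time (this PR's flag)
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
dt = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0)])

buf = io.BytesIO()
torch.save(dt, buf)  # serialized without process-group names
buf.seek(0)
reloaded = torch.load(buf, weights_only=True)

# The reloaded mesh carries no backend info, so reattach a live mesh
# (as the test does) before issuing collectives on the DTensor again.
reloaded._spec.mesh = mesh
DeviceMesh.decouple_backend_at_save = False
```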
@@ -641,6 +674,7 @@ DTensorTestWithLocalTensor = create_local_tensor_test_class(
        # integration
        "test_dtensor_save_load",
        "test_dtensor_save_load_import",
        "test_dtensor_save_load_with_mesh_backend_decouple",
    ],
)
@@ -1051,6 +1051,26 @@ class TestDeviceMeshGetItem(DTensorTestBase):
        )
        w.wait()

    @with_comms
    def test_unflatten_mesh_3d_with_pg_cache(self):
        # Turn on gate for not saving PG names for device mesh when it comes to torch.save.
        # This also turns on pg cache
        DeviceMesh.decouple_backend_at_save = True
        # Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism(EP).
        global_mesh = init_device_mesh(
            self.device_type,
            (8,),
            mesh_dim_names=("world",),
        )
        non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
        ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
        self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
        self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
        # test pg caching when unflatten into same layout.
        self.assertEqual(non_ep_mesh["dp"].get_group(), ep_mesh["dp"].get_group())
        self.assertEqual(non_ep_mesh["tp"].get_group(), ep_mesh["ep_tp"].get_group())
        DeviceMesh.decouple_backend_at_save = False

    @with_comms
    def test_concatenate_2d(self):
        mesh_shape = (2, 4)
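Read as a user-facing recipe, the caching exercised above lets an Expert Parallel view and a dense view of the same world mesh share their common groups; a rough sketch under the same flag and the private `_unflatten` helper used by the test:

```python
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# Assumes an 8-rank job; `_unflatten` is the private helper used by the test above.
DeviceMesh.decouple_backend_at_save = True  # in this PR, also enables the PG cache
world = init_device_mesh("cuda", (8,), mesh_dim_names=("world",))

dense = world._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
moe = world._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))

# Dimensions with identical layouts resolve to the same cached ProcessGroup,
# so the second unflatten does not create duplicate communicators.
assert dense["dp"].get_group() == moe["dp"].get_group()
assert dense["tp"].get_group() == moe["ep_tp"].get_group()
DeviceMesh.decouple_backend_at_save = False
```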
@@ -1,5 +1,6 @@
#pragma once

#include <functional>
#include <memory>
#include <utility>
#include <vector>
@@ -48,6 +49,12 @@ class TORCH_API Backend : public torch::CustomClassHolder {
    const std::string backend;
    std::string group_name;
    std::vector<uint64_t> global_ranks_in_group;

    bool operator==(const Options& other) const noexcept {
      return timeout == other.timeout && backend == other.backend &&
          group_name == other.group_name &&
          global_ranks_in_group == other.global_ranks_in_group;
    }
  };

  explicit Backend(int rank, int size);
@@ -511,3 +518,24 @@ class TORCH_API Backend : public torch::CustomClassHolder {
};

} // namespace c10d

// small helper
inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

namespace std {

template <>
struct hash<c10d::Backend::Options> {
  std::size_t operator()(const c10d::Backend::Options& o) const noexcept {
    std::size_t h = 0;
    hash_combine(h, std::hash<long long>{}(o.timeout.count()));
    hash_combine(h, std::hash<std::string>{}(o.backend));
    hash_combine(h, std::hash<std::string>{}(o.group_name));
    for (auto x : o.global_ranks_in_group)
      hash_combine(h, std::hash<uint64_t>{}(x));
    return h;
  }
};
} // namespace std
@@ -260,6 +260,23 @@ class TORCH_API ProcessGroupGloo : public Backend {

    std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
    int threads;

    bool operator==(const Options& other) const noexcept {
      // 1) compare base first
      if (!static_cast<const Backend::Options&>(*this).operator==(other))
        return false;

      // 2) compare devices by identity
      if (devices.size() != other.devices.size())
        return false;
      for (size_t i = 0; i < devices.size(); ++i) {
        if (devices[i].get() != other.devices[i].get()) // pointer identity
          return false;
      }

      // 3) compare added scalar fields
      return threads == other.threads;
    }
  };

  const std::string getBackendName() const override {
@@ -494,4 +511,24 @@ class TORCH_API ProcessGroupGloo : public Backend {

} // namespace c10d

namespace std {
template <>
struct hash<c10d::ProcessGroupGloo::Options> {
  std::size_t operator()(
      const c10d::ProcessGroupGloo::Options& o) const noexcept {
    std::size_t h = 0;
    // reuse base hash
    hash_combine(
        h,
        std::hash<c10d::Backend::Options>{}(
            static_cast<const c10d::Backend::Options&>(o)));
    // add derived fields
    for (auto const& dev : o.devices)
      hash_combine(h, std::hash<const void*>{}(dev.get()));
    hash_combine(h, std::hash<int>{}(o.threads));
    return h;
  }
};
} // namespace std

#endif // USE_C10D_GLOO
@@ -550,6 +550,33 @@ class TORCH_API ProcessGroupNCCL : public Backend {
    // the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead.
    int split_color{-2};
#endif

    bool operator==(const Options& other) const noexcept {
      // 1) compare base first
      if (!static_cast<const Backend::Options&>(*this).operator==(other))
        return false;

      // 2) simple fields
      if (is_high_priority_stream != other.is_high_priority_stream) {
        return false;
      }
      if (split_color != other.split_color) {
        return false;
      }

      // 3) split_from: compare by identity
      if (split_from.get() != other.split_from.get()) {
        return false;
      }

#ifdef NCCL_HAS_CONFIG
      // 4) config
      if (std::memcmp(&config, &other.config, sizeof(ncclConfig_t)) != 0) {
        return false;
      }
#endif
      return true;
    }
  };

  // Helper class related to TORCH_NCCL_DESYNC_DEBUG
@@ -1504,4 +1531,46 @@ typedef bool (*gil_checker_t)();
TORCH_API gil_checker_t& get_gil_checker();
} // namespace c10d

#ifdef NCCL_HAS_CONFIG
inline std::size_t hash_nccl_config(const ncclConfig_t& cfg) noexcept {
  const unsigned char* p = reinterpret_cast<const unsigned char*>(&cfg);
  std::size_t h = 0;
  for (std::size_t i = 0; i < sizeof(cfg); ++i) {
    hash_combine(h, static_cast<std::size_t>(p[i]));
  }
  return h;
}
#endif

namespace std {

template <>
struct hash<c10d::ProcessGroupNCCL::Options> {
  std::size_t operator()(
      const c10d::ProcessGroupNCCL::Options& o) const noexcept {
    std::size_t h = 0;

    // 1) base
    hash_combine(
        h,
        std::hash<c10d::Backend::Options>{}(
            static_cast<const c10d::Backend::Options&>(o)));

    // 2) trivial extras
    hash_combine(h, std::hash<bool>{}(o.is_high_priority_stream));
    hash_combine(h, std::hash<int>{}(o.split_color));

    // 3) pointer identity for split_from
    hash_combine(h, std::hash<const void*>{}(o.split_from.get()));

#ifdef NCCL_HAS_CONFIG
    // 4) config — option A: hash bytes
    hash_combine(h, hash_nccl_config(o.config));
#endif
    return h;
  }
};

} // namespace std

#endif // USE_C10D_NCCL
@@ -3107,7 +3107,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
          .def_readwrite(
              "global_ranks_in_group",
              &::c10d::Backend::Options::global_ranks_in_group)
          .def_readwrite("group_name", &::c10d::Backend::Options::group_name);
          .def_readwrite("group_name", &::c10d::Backend::Options::group_name)
          .def(
              "__eq__",
              [](const ::c10d::Backend::Options& a,
                 const ::c10d::Backend::Options& b) { return a == b; })
          .def("__hash__", [](const ::c10d::Backend::Options& a) {
            return std::hash<::c10d::Backend::Options>{}(a);
          });

#ifdef USE_C10D_GLOO
  auto processGroupGloo =
@@ -3121,7 +3128,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
          processGroupGloo, "_Options", backendOptions)
          .def(py::init<>())
          .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
          .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
          .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads)
          .def(
              "__eq__",
              [](const ::c10d::ProcessGroupGloo::Options& a,
                 const ::c10d::ProcessGroupGloo::Options& b) { return a == b; })
          .def("__hash__", [](const ::c10d::ProcessGroupGloo::Options& a) {
            return std::hash<::c10d::ProcessGroupGloo::Options>{}(a);
          });

  processGroupGloo
      .def_static(
@@ -3481,6 +3495,15 @@ Example::
              "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
          .def_readwrite(
              "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
          .def(
              "__eq__",
              [](const ::c10d::ProcessGroupNCCL::Options& a,
                 const ::c10d::ProcessGroupNCCL::Options& b) { return a == b; })
          .def(
              "__hash__",
              [](const ::c10d::ProcessGroupNCCL::Options& a) {
                return std::hash<::c10d::ProcessGroupNCCL::Options>{}(a);
              })
          .def(
              "__copy__",
              [](const ::c10d::ProcessGroupNCCL::Options& self) {
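With the `__eq__`, `__hash__`, and `__copy__` bindings added above, backend option objects become comparable and hashable from Python, so they can serve as cache keys; a small sketch assuming an NCCL-enabled build and these new bindings:

```python
import copy

import torch.distributed as dist

opts_a = dist.ProcessGroupNCCL.Options()
opts_b = copy.copy(opts_a)  # uses the __copy__ binding from this diff

assert opts_a == opts_b  # structural equality via the C++ operator==
pg_cache = {opts_a: "pg:default"}  # hashable, so options can key a cache
assert pg_cache[opts_b] == "pg:default"
```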
@@ -951,7 +951,9 @@ class _LocalDeviceMesh:

        coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
        for r in lm.ranks:
            rank_tensor = self._layout.remap_to_tensor(self._rank_map)
            rank_tensor = self._layout.remap_to_tensor(
                self._shared_state.get_rank_map()
            )
            rank_coords = (rank_tensor == r).nonzero().tolist()
            assert len(rank_coords) == 1
            for d, c in enumerate(rank_coords[0][1:]):
@@ -6,7 +6,7 @@ import threading
import warnings
from collections.abc import Iterator
from itertools import zip_longest
from typing import Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import torch
from torch.distributed import is_available
@@ -125,264 +125,67 @@ else:
    """
    return getattr(torch, device_type, None)

class DeviceMesh:
class _SharedState:
    """
    DeviceMesh represents a mesh of devices, where layout of devices could be
    represented as a n-d dimension array, and each value of the n-d dimensional
    array is the global id of the default process group ranks.

    DeviceMesh could be used to setup the N dimensional device connections across the cluster,
    and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on
    each dimension of the DeviceMesh separately. DeviceMesh respects the device that user selects
    already (i.e. if user call `torch.cuda.set_device` before the DeviceMesh initialization),
    and will select/set the device for the current process if user does not set the device
    beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization.

    DeviceMesh can also be used as a context manager when using together with DTensor APIs.

    .. note::
        DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
        is running on all processes/ranks in the cluster. Therefore, users need to make sure the
        `mesh` array (which describes the layout of devices) should be identical across all ranks.
        Inconsistent `mesh` will lead to silent hang.

    Args:
        device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
        mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
            of devices, where the IDs are global IDs of the default process group.
        _rank (int): (experimental/internal)
            The global rank of the current process. If not provided, it will
            be inferred from the default process group.

    Returns:
        DeviceMesh: A :class:`DeviceMesh` object representing the device layout.

    The following program runs on each process/rank in an SPMD manner. In this example, we have 2
    hosts with 4 GPUs each.
    A reduction over the first dimension of mesh will reduce across
    columns (0, 4), .. and (3, 7), a reduction over the second dimension
    of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).

    Example::

        >>> # xdoctest: +SKIP("no rank")
        >>> from torch.distributed.device_mesh import DeviceMesh
        >>>
        >>> # Initialize device mesh as (2, 4) to represent the topology
        >>> # of cross-host(dim 0), and within-host (dim 1).
        >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
    This class is used to store the shared state of the DeviceMesh.
    """

    # Flag to specify device save without backend info. This is a temporary variable
    # We will remove this flag once we fully deprecate the behavior of save a device mesh with pg names.
    _device_type: str
    _rank_map: torch.Tensor
    _mesh_dim_names: Optional[tuple[str, ...]]
    _layout: _MeshLayout
    _root_mesh: Optional["DeviceMesh"] = None
    # Record flatten mesh name to its flattened mesh in root mesh.
    _flatten_mapping: dict[str, "DeviceMesh"]
    _root_mesh: Optional["DeviceMesh"]
    _backend_cache: dict[tuple[_MeshLayout, Optional[C10dBackend.Options]], str]

    def __init__(
        self,
        device_type: str,
        mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
        *,
        mesh_dim_names: Optional[tuple[str, ...]] = None,
        backend_override: Optional[tuple[BackendConfig, ...]] = None,
        _init_backend: bool = True,
        _rank: Optional[int] = None,
        _layout: Optional[_MeshLayout] = None,
        _rank_map: Optional[torch.Tensor] = None,
        _root_mesh: Optional["DeviceMesh"] = None,
        rank_map: torch.Tensor,
        root_mesh: Optional["DeviceMesh"] = None,
    ) -> None:
        if mesh is not None:
            if _layout is not None or _rank_map is not None:
                raise TypeError(
                    "Cannot provide _layout and/or _rank_map if passing explicit mesh"
                )
            if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
                raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
            mesh_tensor = (
                mesh.detach().to(dtype=torch.int).contiguous()
                if isinstance(mesh, torch.Tensor)
                else torch.tensor(mesh, device="cpu", dtype=torch.int)
            )
            _layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
            _rank_map = mesh_tensor.flatten()
        else:
            if _layout is None or _rank_map is None:
                raise TypeError(
                    "The mesh argument is required except for PRIVATE USAGE ONLY!"
                )

        assert _layout.check_non_overlap(), (
            "Please use a non-overlapping layout when creating a DeviceMesh."
        )
        assert _rank_map.ndim == 1, "The rank map must be 1-dimensional"
        assert _rank_map.is_contiguous(), "The rank map must be contiguous"
        assert _rank_map.numel() >= _layout.cosize(), (
            f"The rank map contains {_rank_map.numel()} element, "
            f"which isn't large enough for layout {_layout}"
        )

        self._device_type = device_type
        self._layout = _layout
        self._rank_map = _rank_map
        self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
        self._root_mesh = _root_mesh
        self._rank_map = rank_map
        self._root_mesh = root_mesh
        self._backend_cache: dict[
            tuple[_MeshLayout, Optional[C10dBackend.Options]], str
        ] = {}
        self.pg_cache_enabled = DeviceMesh.decouple_backend_at_save

        if backend_override is None:
            backend_override = ((None, None),) * len(self._layout)
        elif len(backend_override) != len(self._layout):
            raise ValueError(
                f"backend_override should have the same length as the number of mesh dimensions, "
                f"but got {len(backend_override)} and {len(self._layout)}."
            )
        # Internal bookkeeping for the device mesh.
        self._layout = (
            _layout
            if _layout
            else _MeshLayout(self.mesh.size(), self.mesh.stride())
        )
        if not self._layout.check_non_overlap():
            raise AssertionError(
                "Please use a non-overlapping layout when creating a DeviceMesh."
            )
        # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
        if self._layout.numel() != self.mesh.numel():
            raise AssertionError(
                "Please use a valid layout when creating a DeviceMesh."
                f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
            )
    def __post_init__(self):
        assert self._rank_map.ndim == 1, "The rank map must be 1-dimensional"
        assert self._rank_map.is_contiguous(), "The rank map must be contiguous"

        # private field to pre-generate DeviceMesh's hash
        self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
        self._thread_id = None
        # Initialize instance-specific flatten mapping
        self._flatten_mapping = {}
    def get_rank_map(self) -> torch.Tensor:
        return self._rank_map

        # Skip process group initialization if xla device or init backend is False
        # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
        self._thread_id = None
        if device_type != "xla":
            # always try to create default (world) pg, even if it is not initialized
            # already. The world pg is used for device mesh identity (rank) on each
            # process (we need to know if the current global rank is in the mesh or not).
            if _init_backend:
                self._setup_world_group_and_device()
                self._dim_group_names = self._init_process_groups(
                    self._layout,
                    self._rank_map,
                    self._mesh_dim_names,
                    backend_override,
                )
    def get_root_mesh(self) -> Optional["DeviceMesh"]:
        return self._root_mesh

            if is_initialized() and get_backend() == "threaded":
                # pyrefly: ignore [bad-assignment]
                self._thread_id = threading.get_ident()

            if _rank is None:
                _rank = get_rank()

            # calculate the coordinates of the current global rank on the mesh
            rank_coords = (self.mesh == _rank).nonzero()
            if rank_coords.size(0) not in (0, 1):
                raise AssertionError(
                    f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}"
                )
            self._coordinate_on_dim: Optional[list[int]] = (
                rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
            )

        # private field to pre-generate DeviceMesh's hash
        self._flatten_rank_map = tuple(self._rank_map.tolist())
        # Initialize instance-specific flatten mapping
        self._flatten_mapping = {}

    @property
    def device_type(self) -> str:
        """Returns the device type of the mesh."""
    def get_device_type(self) -> str:
        return self._device_type

    @property
    def mesh(self) -> torch.Tensor:
        """Returns the tensor representing the layout of devices."""
        full_mesh = self._layout.remap_to_tensor(self._rank_map)
        if full_mesh.size(0) == 1:
            return full_mesh[0]
        my_coords = (full_mesh == get_rank()).nonzero()
        if my_coords.size(0) > 0:
            return full_mesh[my_coords[0, 0]]
        raise RuntimeError(
            "In order to get the mesh Tensor of a DeviceMesh it needs to "
            "either have all its original dimensions (e.g., no slicing) "
            "or it needs to contain the local rank"
        )
    def update_backend_cache(
        self,
        layout: _MeshLayout,
        backend: str,
        pg_option: Optional[C10dBackend.Options],
    ) -> None:
        if (layout, pg_option) not in self._backend_cache:
            self._backend_cache[(layout, pg_option)] = backend

    @property
    def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
        """Returns the names of mesh dimensions."""
        return self._mesh_dim_names
    def get_backend_from_cache(
        self, layout: _MeshLayout, pg_option: Optional[C10dBackend.Options]
    ) -> Optional[str]:
        return self._backend_cache.get((layout, pg_option), None)

    def _setup_world_group_and_device(self):
        default_initialized = is_initialized()
        # TODO: think about how to allow pg options to be passed to world group
        # or mesh dimension groups
        if not default_initialized:
            init_process_group()

        world_size = get_world_size()
        if self._layout.numel() > world_size:
            raise RuntimeError(
                f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
            )

        # ONLY set the device if the current device is not initialized, if user already
        # set the device before DeviceMesh init, we respect the user's choice.
        device_handle = _get_device_handle(self._device_type)
        if device_handle and not device_handle.is_initialized():
            # auto set the cuda/cuda-like device only if user has not set it, if there's LOCAL_RANK
            # env variable from launchers, we use it to set the device.
            if "LOCAL_RANK" in os.environ:
                local_rank = int(os.environ["LOCAL_RANK"])
                logger.info(
                    "Setting default device for the current process based on LOCAL_RANK=%s",
                    local_rank,
                )
                device_handle.set_device(local_rank)
            else:
                warnings.warn(
                    "It seems like you did not set/select the default device for the current process before the DeviceMesh "
                    "initialization or use a launcher (i.e. torchrun) which populates `LOCAL_RANK` environment variable. "
                    "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that "
                    "the underlying communicator (i.e. NCCL) can be initialized properly. "
                    "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the "
                    "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. ",
                    stacklevel=2,
                )
                # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
                # NOTE: This device selection would only work for homogeneous hardware.
                num_devices_per_host = device_handle.device_count()
                if (
                    world_size > num_devices_per_host
                    and world_size % num_devices_per_host != 0
                ):
                    raise RuntimeError(
                        f"DeviceMesh only support homogeneous hardware, but found "
                        f"{world_size} ranks and {num_devices_per_host} {self._device_type} devices!"
                    )
                device_handle.set_device(get_rank() % num_devices_per_host)

        return _get_default_group()

    @staticmethod
    def _init_one_process_group(
        self,
        sub_layout: _MeshLayout,
        rank_map: torch.Tensor,
        dim_name: str,
        backend_override: BackendConfig,
    ) -> Optional[str]:
        # Generate a 2D global mesh tensor for the current dim for PG creation.
        pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map)
        pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(self._rank_map)
        backend, pg_options = backend_override
        # We need to explicitly pass in timeout when specified in option, otherwise
        # the default timeout will be used to override the timeout set in option.
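Conceptually, the `_backend_cache` on `_SharedState` keys previously created process groups by `(layout, backend options)`, so a second request for the same layout hands back the existing group name instead of creating a new communicator. A standalone sketch of that idea, with hypothetical names and plain tuples standing in for `_MeshLayout` and `C10dBackend.Options`:

```python
# Conceptual sketch only (hypothetical names): hand back an existing group
# name when the same (layout, options) pair is requested again.
_backend_cache: dict[tuple, str] = {}


def get_or_create_group(layout, options, create_fn) -> str:
    key = (layout, options)
    if key in _backend_cache:
        return _backend_cache[key]  # cache hit: reuse, no new communicator
    group_name = create_fn(layout, options)  # cache miss: create a new ProcessGroup
    _backend_cache[key] = group_name
    return group_name
```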
@@ -469,10 +272,9 @@ else:
        pg_name = dim_group.group_name
        return pg_name

    @staticmethod
    def _init_process_groups(
        self,
        layout: _MeshLayout,
        rank_map: torch.Tensor,
        mesh_dim_names: Optional[tuple[str, ...]],
        backend_override: tuple[BackendConfig, ...],
    ) -> list[str]:
@@ -482,18 +284,270 @@ else:
        # create sub pgs base on the mesh argument specified
        for dim in range(len(layout)):
            dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}"
            dim_group_names.append(
                DeviceMesh._init_one_process_group(  # type: ignore[arg-type]
                    layout[dim], rank_map, dim_name, backend_override[dim]
            backend_cache = None
            if self.pg_cache_enabled:
                backend_cache = self.get_backend_from_cache(
                    layout[dim], backend_override[dim][1]
                )
            if backend_cache is not None:
                dim_group_names.append(backend_cache)
            else:
                dim_group_names.append(
                    self._init_one_process_group(  # type: ignore[arg-type]
                        layout[dim], dim_name, backend_override[dim]
                    )
                )
                if dim_group_names[-1] is not None and self.pg_cache_enabled:
                    self.update_backend_cache(
                        layout[dim], dim_group_names[-1], backend_override[dim][1]
                    )
            )
        if any(n is None for n in dim_group_names):
            assert all(n is None for n in dim_group_names)
            return []
        return dim_group_names

torch.serialization.add_safe_globals([_SharedState])

class DeviceMesh:
    """
    DeviceMesh represents a mesh of devices, where layout of devices could be
    represented as a n-d dimension array, and each value of the n-d dimensional
    array is the global id of the default process group ranks.

    DeviceMesh could be used to setup the N dimensional device connections across the cluster,
    and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on
    each dimension of the DeviceMesh separately. DeviceMesh respects the device that user selects
    already (i.e. if user call `torch.cuda.set_device` before the DeviceMesh initialization),
    and will select/set the device for the current process if user does not set the device
    beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization.

    DeviceMesh can also be used as a context manager when using together with DTensor APIs.

    .. note::
        DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
        is running on all processes/ranks in the cluster. Therefore, users need to make sure the
        `mesh` array (which describes the layout of devices) should be identical across all ranks.
        Inconsistent `mesh` will lead to silent hang.

    Args:
        device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
        mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
            of devices, where the IDs are global IDs of the default process group.
        _rank (int): (experimental/internal)
            The global rank of the current process. If not provided, it will
            be inferred from the default process group.

    Returns:
        DeviceMesh: A :class:`DeviceMesh` object representing the device layout.

    The following program runs on each process/rank in an SPMD manner. In this example, we have 2
    hosts with 4 GPUs each.
    A reduction over the first dimension of mesh will reduce across
    columns (0, 4), .. and (3, 7), a reduction over the second dimension
    of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).

    Example::

        >>> # xdoctest: +SKIP("no rank")
        >>> from torch.distributed.device_mesh import DeviceMesh
        >>>
        >>> # Initialize device mesh as (2, 4) to represent the topology
        >>> # of cross-host(dim 0), and within-host (dim 1).
        >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
    """

    # Flag to specify device save without backend info. This is a temporary variable
    # We will remove this flag once we fully deprecate the behavior of save a device mesh with pg names.
    decouple_backend_at_save = False
    _mesh_dim_names: Optional[tuple[str, ...]]
    _layout: _MeshLayout
    # Record flatten mesh name to its flattened mesh in root mesh.
    _flatten_mapping: dict[str, "DeviceMesh"]
    _shared_state: _SharedState

    def __init__(
        self,
        device_type: str,
        mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
        *,
        mesh_dim_names: Optional[tuple[str, ...]] = None,
        backend_override: Optional[tuple[BackendConfig, ...]] = None,
        _init_backend: bool = True,
        _rank: Optional[int] = None,
        _layout: Optional[_MeshLayout] = None,
        _shared_state: Optional[_SharedState] = None,
    ) -> None:
        if mesh is not None:
            if _layout is not None or _shared_state is not None:
                raise TypeError(
                    "Cannot provide _layout and/or _shared_state if passing explicit mesh"
                )
            if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
                raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
            mesh_tensor = (
                mesh.detach().to(dtype=torch.int).contiguous()
                if isinstance(mesh, torch.Tensor)
                else torch.tensor(mesh, device="cpu", dtype=torch.int)
            )
            _layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
            rank_map = mesh_tensor.flatten()
            self._shared_state = _SharedState(
                device_type=device_type, rank_map=rank_map, root_mesh=self
            )
        else:
            if _layout is None or _shared_state is None:
                raise TypeError(
                    "The mesh argument is required except for PRIVATE USAGE ONLY!"
                )
            rank_map = _shared_state.get_rank_map()
            self._shared_state = _shared_state
            if self._shared_state.get_root_mesh() is None:
                self._shared_state._root_mesh = self

        if not _layout.check_non_overlap():
            raise AssertionError(
                "Please use a non-overlapping layout when creating a DeviceMesh."
            )

        # Internal bookkeeping for the device mesh.
        self._layout = _layout
        assert self._shared_state.get_rank_map().numel() >= self._layout.cosize(), (
            f"The rank map contains {rank_map.numel()} element, "
            f"which isn't large enough for layout {self._layout}"
        )

        self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
        if backend_override is None:
            backend_override = ((None, None),) * len(self._layout)
        elif len(backend_override) != len(self._layout):
            raise ValueError(
                f"backend_override should have the same length as the number of mesh dimensions, "
                f"but got {len(backend_override)} and {len(self._layout)}."
            )

        # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
        if self._layout.numel() != self.mesh.numel():
            raise AssertionError(
                "Please use a valid layout when creating a DeviceMesh."
                f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
            )

        # private field to pre-generate DeviceMesh's hash
        self._flatten_rank_map = tuple(self._shared_state.get_rank_map().tolist())
        self._thread_id = None
        # Initialize instance-specific flatten mapping
        self._flatten_mapping = {}

        # Skip process group initialization if xla device or init backend is False
        # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
        if device_type != "xla":
            # always try to create default (world) pg, even if it is not initialized
            # already. The world pg is used for device mesh identity (rank) on each
            # process (we need to know if the current global rank is in the mesh or not).
            if _init_backend:
                self._setup_world_group_and_device()
                self._dim_group_names = self._shared_state._init_process_groups(
                    self._layout,
                    self._mesh_dim_names,
                    backend_override,
                )

            if is_initialized() and get_backend() == "threaded":
                # pyrefly: ignore [bad-assignment]
                self._thread_id = threading.get_ident()

            if _rank is None:
                _rank = get_rank()

            # calculate the coordinates of the current global rank on the mesh
            rank_coords = (self.mesh == _rank).nonzero()
            if rank_coords.size(0) not in (0, 1):
                raise AssertionError(
                    f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}"
                )
            self._coordinate_on_dim: Optional[list[int]] = (
                rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
            )

    @property
    def device_type(self) -> str:
        """Returns the device type of the mesh."""
        return self._shared_state.get_device_type()

    @property
    def mesh(self) -> torch.Tensor:
        """Returns the tensor representing the layout of devices."""
        full_mesh = self._layout.remap_to_tensor(self._shared_state.get_rank_map())
        if full_mesh.size(0) == 1:
            return full_mesh[0]
        my_coords = (full_mesh == get_rank()).nonzero()
        if my_coords.size(0) > 0:
            return full_mesh[my_coords[0, 0]]
        raise RuntimeError(
            "In order to get the mesh Tensor of a DeviceMesh it needs to "
            "either have all its original dimensions (e.g., no slicing) "
            "or it needs to contain the local rank"
        )

    @property
    def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
        """Returns the names of mesh dimensions."""
        return self._mesh_dim_names

    def _setup_world_group_and_device(self):
        default_initialized = is_initialized()
        # TODO: think about how to allow pg options to be passed to world group
        # or mesh dimension groups
        if not default_initialized:
            init_process_group()

        world_size = get_world_size()
        if self._layout.numel() > world_size:
            raise RuntimeError(
                f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
            )

        # ONLY set the device if the current device is not initialized, if user already
        # set the device before DeviceMesh init, we respect the user's choice.
        device_handle = _get_device_handle(self._shared_state.get_device_type())
        if device_handle and not device_handle.is_initialized():
            # auto set the cuda/cuda-like device only if user has not set it, if there's LOCAL_RANK
            # env variable from launchers, we use it to set the device.
            if "LOCAL_RANK" in os.environ:
                local_rank = int(os.environ["LOCAL_RANK"])
                logger.info(
                    "Setting default device for the current process based on LOCAL_RANK=%s",
                    local_rank,
                )
                device_handle.set_device(local_rank)
            else:
                warnings.warn(
                    "It seems like you did not set/select the default device for the current process before the DeviceMesh "
                    "initialization or use a launcher (i.e. torchrun) which populates `LOCAL_RANK` environment variable. "
                    "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that "
                    "the underlying communicator (i.e. NCCL) can be initialized properly. "
                    "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the "
                    "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. ",
                    stacklevel=2,
                )
                # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
                # NOTE: This device selection would only work for homogeneous hardware.
                num_devices_per_host = device_handle.device_count()
                if (
                    world_size > num_devices_per_host
                    and world_size % num_devices_per_host != 0
                ):
                    raise RuntimeError(
                        f"DeviceMesh only support homogeneous hardware, but found "
                        f"{world_size} ranks and {num_devices_per_host} {self._shared_state.get_device_type()} devices!"
                    )
                device_handle.set_device(get_rank() % num_devices_per_host)

        return _get_default_group()

    def _get_root_mesh(self) -> "DeviceMesh":
        return self._root_mesh if self._root_mesh else self
        root_mesh = self._shared_state.get_root_mesh()
        return root_mesh if root_mesh is not None else self

    def __enter__(self) -> "DeviceMesh":
        # set this mesh as the current mesh in mesh env
@@ -523,9 +577,9 @@ else:
        if not self._hash:
            self._hash = hash(
                (
                    self._flatten_rank_map,
                    self._get_universe_id(),
                    self._layout,
                    self._device_type,
                    self._shared_state.get_device_type(),
                    self._mesh_dim_names,
                    self._thread_id,
                )
@@ -538,9 +592,10 @@ else:
        if not isinstance(other, DeviceMesh):
            return False
        return (
            self._flatten_rank_map == other._flatten_rank_map
            self._get_universe_id() == other._get_universe_id()
            and self._layout == other._layout
            and self._device_type == other._device_type
            and self._shared_state.get_device_type()
            == other._shared_state.get_device_type()
            and self._mesh_dim_names == other._mesh_dim_names
            and self._thread_id == other._thread_id
        )
@@ -695,12 +750,11 @@ else:
            ]
        )
        res_submesh = DeviceMesh(
            self._device_type,
            self._shared_state.get_device_type(),
            _layout=layout,
            _rank_map=root_mesh._rank_map,
            mesh_dim_names=submesh_dim_names,
            _root_mesh=root_mesh,
            _init_backend=False,
            _shared_state=root_mesh._shared_state,
        )
        res_submesh._dim_group_names = slice_dim_group_name
        return res_submesh
@@ -745,12 +799,11 @@ else:
        )

        res_flattened_mesh = DeviceMesh(
            root_mesh._device_type,
            root_mesh._shared_state.get_device_type(),
            _layout=flattened_mesh_layout,
            _rank_map=root_mesh._rank_map,
            mesh_dim_names=(mesh_dim_name,),
            _root_mesh=root_mesh,
            backend_override=(backend_override,),
            _shared_state=root_mesh._shared_state,
        )
        root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh
@@ -874,12 +927,12 @@ else:
        """
        mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
        layout = self._layout[mesh_dim]
        pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
        pg_ranks_by_dim = layout.remap_to_tensor(self._shared_state.get_rank_map())
        cur_rank = self.get_rank()
        res_submeshes = []
        for mesh_1d in pg_ranks_by_dim:
            submesh = DeviceMesh(
                self._device_type,
                self._shared_state.get_device_type(),
                mesh_1d,
                mesh_dim_names=(mesh_dim_name,),
                _init_backend=False,
@@ -1051,6 +1104,12 @@ else:
            )
        return not_none(get_rank(mesh_dim_group))

    def _get_universe_id(self) -> Union[tuple[int, ...], int]:
        if self.decouple_backend_at_save:
            return id(self._shared_state.get_rank_map())
        else:
            return self._flatten_rank_map

    def get_coordinate(self) -> Optional[list[int]]:
        """
        Return the relative indices of this rank relative to all
@@ -1118,10 +1177,9 @@ else:
        res_mesh = DeviceMesh(
            self.device_type,
            _layout=unflattened_layout,
            _rank_map=root_mesh._rank_map,
            mesh_dim_names=tuple(unflattened_mesh_dim_names),
            _root_mesh=root_mesh,
            _init_backend=False,
            _shared_state=root_mesh._shared_state,
        )

        # If original mesh has initiated its backend, we need to initialize the backend
@@ -1130,11 +1188,12 @@ else:
        # per dim backend init.
        if hasattr(self, "_dim_group_names"):
            dim_group_names = self._dim_group_names.copy()
            dim_group_names[dim : dim + 1] = self._init_process_groups(
                partial_layout,
                root_mesh._rank_map,
                mesh_dim_names,
                backend_override,
            dim_group_names[dim : dim + 1] = (
                root_mesh._shared_state._init_process_groups(
                    partial_layout,
                    mesh_dim_names,
                    backend_override,
                )
            )
            res_mesh._dim_group_names = dim_group_names
@@ -1210,7 +1269,7 @@ else:
        concat_sizes: list[IntTuple] = []
        concat_strides: list[IntTuple] = []
        concat_dim_group_name: list[str] = []
        flatten_rank_map = device_mesh_list[0]._flatten_rank_map
        mesh_universe_id = device_mesh_list[0]._get_universe_id()
        for dm in device_mesh_list:
            for i in range(len(dm._layout)):
                concat_sizes.append(dm._layout[i].sizes)
@@ -1219,7 +1278,7 @@ else:
                concat_dim_group_name.extend(not_none(dm._dim_group_names))
            # Concatenate device mesh having different root mesh tensors are meaningless
            # because the concatenated indices should be indexed by the same root mesh tensor.
            if dm._flatten_rank_map != flatten_rank_map:
            if dm._get_universe_id() != mesh_universe_id:
                raise RuntimeError(
                    "Cannot concatenate DeviceMeshes derived from different device meshs"
                )
@@ -1231,14 +1290,109 @@ else:
        res_mesh = DeviceMesh(
            device_mesh_list[0].device_type,
            _layout=concat_mesh_layout,
            _rank_map=device_mesh_list[0]._rank_map,
            mesh_dim_names=tuple(concat_dim_names),
            _root_mesh=device_mesh_list[0]._get_root_mesh(),
            _init_backend=False,
            _shared_state=device_mesh_list[0]._shared_state,
        )
        res_mesh._dim_group_names = concat_dim_group_name
        return res_mesh

    def __getstate__(self):
        """
        Returns the state of the DeviceMesh as a dictionary for serialization,
        which contains all necessary information to reconstruct the DeviceMesh.
        """
        shared_state = {
            "device_type": self._shared_state._device_type,
            "rank_map": self._shared_state._rank_map,
        }
        if self._shared_state._root_mesh != self:
            shared_state["root_mesh"] = not_none(
                self._shared_state._root_mesh
            ).__getstate__()

        state: dict[str, Any] = {
            "shared_state": shared_state,
            "layout": self._layout,
            "mesh_dim_names": self._mesh_dim_names,
            "thread_id": self._thread_id,
            "coordinate_on_dim": getattr(self, "_coordinate_on_dim", None),
        }

        # Serialize flatten_mapping
        flatten_mapping: dict[str, Any] = {}
        for mesh_name, mesh in self._flatten_mapping.items():
            flatten_mapping[mesh_name] = mesh.__getstate__()
        state["flatten_mapping"] = flatten_mapping

        if not self.decouple_backend_at_save and hasattr(self, "_dim_group_names"):
            logger.warning(
                "Save device mesh via torch.save with pg names and will be deprecated in PT 2.11. "
                "Users are welcome to use Distributed checkpoint (DCP) or re-create pgs in the same order"
                "as the original device mesh."
            )
            state["dim_group_names"] = self._dim_group_names

        return state

    def __setstate__(self, state):
        """
        Restores the DeviceMesh state from a state dictionary.
        """
        required_keys = {
            "shared_state",
            "layout",
            "mesh_dim_names",
            "thread_id",
            "coordinate_on_dim",
            "flatten_mapping",
        }
        missing_keys = required_keys - state.keys()
        if missing_keys:
            raise ValueError(f"state_dict is missing required keys: {missing_keys}")

        # Restore shared_state
        shared_state = state["shared_state"]

        # First, restore root_mesh if it exists (we need to do this before creating _SharedState)
        root_mesh = None
        if shared_state.get("root_mesh") is not None:
            # Create a new DeviceMesh for the root mesh
            root_mesh = DeviceMesh.__new__(DeviceMesh)
            root_mesh.__setstate__(shared_state["root_mesh"])

        # Create and initialize the shared state
        self._shared_state = _SharedState(
            device_type=shared_state["device_type"],
            rank_map=shared_state["rank_map"],
            root_mesh=root_mesh,
        )

        # Restore other attributes
        self._layout = state["layout"]
        self._mesh_dim_names = state["mesh_dim_names"]
        self._thread_id = state["thread_id"]
        if state.get("coordinate_on_dim") is not None:
            self._coordinate_on_dim = state["coordinate_on_dim"]

        # Re-initialize internal bookkeeping
        self._flatten_rank_map = tuple(self._shared_state._rank_map.tolist())

        # Restore flatten_mapping
        self._flatten_mapping = {}
        if state.get("flatten_mapping"):
            for mesh_name, mesh_state in state["flatten_mapping"].items():
                flatten_mesh = DeviceMesh.__new__(DeviceMesh)
                flatten_mesh.__setstate__(mesh_state)
                self._flatten_mapping[mesh_name] = flatten_mesh

        # We don't recommend load from saved pg names, because users need to ensure the same
        # order in creating process groups when we save the device mesh.
        # This is implicit and error-prone. We will remove this behavior soon.
        # What we recommend users to do is to explicitly create PGs and set it to the loaded mesh.
        if state.get("dim_group_names"):
            self._dim_group_names = state["dim_group_names"]

def _normalize_backend_override(
    backend_override: dict[
        Union[int, str],
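Taken together, `__getstate__`/`__setstate__` let a `DeviceMesh` round-trip through `torch.save`/`torch.load`; a rough usage sketch under this diff's behavior (with the decouple flag set, the reloaded mesh carries no process-group names, so communicators must be re-created or reattached by the caller). It assumes an initialized 8-rank job:

```python
import io

import torch
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# Assumes an initialized 8-rank job (for example, 2 hosts x 4 GPUs).
DeviceMesh.decouple_backend_at_save = True
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

buf = io.BytesIO()
torch.save(mesh, buf)  # pickled via __getstate__, without PG names
buf.seek(0)
restored = torch.load(buf, weights_only=False)

# Layout and dim names survive; communicators do not (per this diff),
# so the caller re-creates groups or reattaches an existing mesh.
assert restored.mesh_dim_names == ("dp", "tp")
assert not hasattr(restored, "_dim_group_names")
DeviceMesh.decouple_backend_at_save = False
```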
@@ -1363,12 +1517,13 @@ else:
        # external device type has been set to be (e.g. meta)
        with torch.device("cpu"):
            rank_map = torch.arange(layout.numel(), dtype=torch.int)
        shared_state = _SharedState(device_type=device_type, rank_map=rank_map)
        device_mesh = DeviceMesh(
            device_type=device_type,
            _layout=layout,
            _rank_map=rank_map,
            mesh_dim_names=mesh_dim_names,
            backend_override=backend_override_tuple,
            _shared_state=shared_state,
        )

        return device_mesh