Files
pytorch/torch/distributed/_mesh_layout.py
fduwjj 19a4ef0256 [DeviceMesh] Make CuTe layout as mesh layout to be ready for using in DeviceMesh (#162414)
We create a wrapper class named "_MeshLayout" that acts as the layout for a device mesh, so that we can add new methods specific to DeviceMesh while keeping the core CuTe manipulation logic inside the pycute module. This PR creates the main body of the code; the next PR will add the actual implementation and unit tests for the device mesh layout. (The actual implementation can be found in https://github.com/pytorch/pytorch/pull/161016.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162414
Approved by: https://github.com/ezyang, https://github.com/fegin
ghstack dependencies: #162413, #162534
2025-09-15 17:04:41 +00:00

72 lines
2.1 KiB
Python

"""
Definition of CuTe inspired Layouts for DeviceMesh internal bookkeeping and functions to manipulate them
"""
import math
from collections.abc import Iterator
from dataclasses import dataclass
from torch.distributed._pycute import (
coalesce,
complement,
composition,
flatten,
IntTuple,
is_int,
is_tuple,
Layout,
)
def _has_matching_structure(shape: IntTuple, stride: IntTuple) -> bool:
    """Return True if ``shape`` and ``stride`` share the same nested tuple structure.

    Two values match when both are non-tuples (leaves), or when both are tuples of
    the same length whose elements match pairwise.
    """
    if is_tuple(shape) and is_tuple(stride):
        return len(shape) == len(stride) and all(
            _has_matching_structure(sub_shape, sub_stride)
            for sub_shape, sub_stride in zip(shape, stride)
        )
    return not is_tuple(shape) and not is_tuple(stride)


@dataclass(frozen=True, init=True)
class _MeshLayout(Layout):
    """Immutable CuTe-style layout (``shape`` plus ``stride``) used by DeviceMesh.

    Thin wrapper over pycute's ``Layout``: the layout algebra (coalesce,
    composition, complement, indexing) is delegated to ``torch.distributed._pycute``.
    Every result is re-wrapped as ``_MeshLayout``, so the validation in
    ``__post_init__`` also runs on derived layouts.
    """

    shape: IntTuple
    stride: IntTuple

    def __post_init__(self) -> None:
        """Validate the types and structure of ``shape`` and ``stride``.

        Raises:
            TypeError: if ``shape`` or ``stride`` is neither a tuple nor an int.
            ValueError: if ``shape`` and ``stride`` are both tuples with a
                different number of flattened entries, or if their nested
                structure differs (e.g. a tuple shape paired with an int stride).
        """
        if not is_tuple(self.shape) and not is_int(self.shape):
            raise TypeError(f"shape must be a tuple or int, got {type(self.shape)}")
        if not is_tuple(self.stride) and not is_int(self.stride):
            raise TypeError(f"stride must be a tuple or int, got {type(self.stride)}")
        if (
            is_tuple(self.shape)
            and is_tuple(self.stride)
            and len(flatten(self.shape)) != len(flatten(self.stride))
        ):
            raise ValueError(
                f"sizes {len(flatten(self.shape))} and "
                f"strides {len(flatten(self.stride))} must have the same length"
            )
        # The length check above misses a tuple shape paired with an int stride
        # (or vice versa) and nested mismatches with equal flattened length,
        # e.g. shape ((2, 3), 4) vs stride (2, (3, 4)). Such layouts would only
        # fail later, deep inside pycute's indexing code.
        if not _has_matching_structure(self.shape, self.stride):
            raise ValueError(
                f"sizes {self.shape} and strides {self.stride} "
                "must have the same nested structure"
            )

    @property
    def sizes(self) -> IntTuple:
        """Alias for ``shape``."""
        return self.shape

    @property
    def strides(self) -> IntTuple:
        """Alias for ``stride``."""
        return self.stride

    @property
    def sizes_and_strides(self) -> Iterator[tuple[int, int]]:
        """Pair up flattened sizes and strides.

        A fresh one-shot iterator is returned on every access.
        """
        return zip(flatten(self.shape), flatten(self.stride))

    def numel(self) -> int:
        """Return the product of all flattened sizes."""
        return math.prod(flatten(self.shape))

    # operator [] (get-i like tuples)
    def __getitem__(self, i: int) -> "_MeshLayout":
        """Return pycute's ``Layout.__getitem__(i)`` re-wrapped as ``_MeshLayout``."""
        layout = super().__getitem__(i)
        return _MeshLayout(layout.shape, layout.stride)

    def coalesce(self) -> "_MeshLayout":
        """Return pycute's ``coalesce(self)`` re-wrapped as ``_MeshLayout``."""
        layout = coalesce(self)
        return _MeshLayout(layout.shape, layout.stride)

    def composition(self, layout: "_MeshLayout") -> "_MeshLayout":
        """Return pycute's ``composition(self, layout)`` re-wrapped as ``_MeshLayout``."""
        result = composition(self, layout)
        return _MeshLayout(result.shape, result.stride)

    def complement(self, world_size: int) -> "_MeshLayout":
        """Return pycute's ``complement(self, world_size)`` re-wrapped as ``_MeshLayout``.

        ``world_size`` is forwarded unchanged as pycute's second ``complement``
        argument (presumably the total rank count to cover; confirm against pycute).
        """
        layout = complement(self, world_size)
        return _MeshLayout(layout.shape, layout.stride)