mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
We create a wrapper class named "_MeshLayout" that acts as a layout for the device mesh, so that we can add new methods specific to DeviceMesh while keeping the core CuTe manipulation logic inside the pycute module. This PR creates the main body of the code; the next PR will add the actual implementation and unit tests for the device mesh layout. (The actual implementation can be found in https://github.com/pytorch/pytorch/pull/161016) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162414 Approved by: https://github.com/ezyang, https://github.com/fegin ghstack dependencies: #162413, #162534
72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
"""
|
|
Definition of CuTe inspired Layouts for DeviceMesh internal bookkeeping and functions to manipulate them
|
|
"""
|
|
|
|
import math
|
|
from collections.abc import Iterator
|
|
from dataclasses import dataclass
|
|
|
|
from torch.distributed._pycute import (
|
|
coalesce,
|
|
complement,
|
|
composition,
|
|
flatten,
|
|
IntTuple,
|
|
is_int,
|
|
is_tuple,
|
|
Layout,
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True, init=True)
class _MeshLayout(Layout):
    """Frozen dataclass wrapper around the pycute ``Layout`` for DeviceMesh use.

    The class stores a ``shape`` and a ``stride`` and re-exposes a handful of
    pycute operations (indexing, ``coalesce``, ``composition``,
    ``complement``) so that each of them returns a ``_MeshLayout`` built from
    the ``shape``/``stride`` of the pycute result.
    """

    # Sizes of the layout; must be an int or a tuple (checked in __post_init__).
    shape: IntTuple
    # Strides paired with ``shape``; must be an int or a tuple (checked in
    # __post_init__).
    stride: IntTuple

    def __post_init__(self) -> None:
        """Validate the ``shape`` and ``stride`` supplied to the constructor.

        Raises:
            TypeError: if ``shape`` or ``stride`` is neither a tuple nor an int.
            ValueError: if the flattened ``shape`` and flattened ``stride``
                do not contain the same number of entries.
        """
        if not is_tuple(self.shape) and not is_int(self.shape):
            raise TypeError(f"shape must be a tuple or int, got {type(self.shape)}")
        if not is_tuple(self.stride) and not is_int(self.stride):
            raise TypeError(f"stride must be a tuple or int, got {type(self.stride)}")

        # Compare the flattened lengths unconditionally.  Restricting this
        # check to the "both are tuples" case lets a mixed pair such as
        # shape=(2, 3), stride=1 through, and ``sizes_and_strides`` would then
        # silently drop entries when zipping.  ``numel`` already calls
        # ``flatten`` on a shape that may be a bare int, so no new assumption
        # about ``flatten`` is introduced here.
        num_sizes = len(flatten(self.shape))
        num_strides = len(flatten(self.stride))
        if num_sizes != num_strides:
            raise ValueError(
                f"sizes {num_sizes} and "
                f"strides {num_strides} must have the same length"
            )

    @property
    def sizes(self) -> IntTuple:
        """Return ``shape`` unchanged (alias)."""
        return self.shape

    @property
    def strides(self) -> IntTuple:
        """Return ``stride`` unchanged (alias)."""
        return self.stride

    @property
    def sizes_and_strides(self) -> Iterator[tuple[int, int]]:
        """Return an iterator of ``(size, stride)`` pairs over the flattened fields.

        Every access builds a new ``zip`` object; like any iterator it can
        only be consumed once.
        """
        return zip(flatten(self.shape), flatten(self.stride))

    def numel(self) -> int:
        """Return the product of all entries of the flattened ``shape``."""
        return math.prod(flatten(self.shape))

    # operator [] (get-i like tuples)
    def __getitem__(self, i: int) -> "_MeshLayout":
        """Index through ``Layout.__getitem__`` and re-wrap as ``_MeshLayout``.

        Args:
            i: index forwarded unchanged to the parent implementation.

        Returns:
            A ``_MeshLayout`` built from the ``shape`` and ``stride`` of the
            parent's result.
        """
        layout = super().__getitem__(i)
        return _MeshLayout(layout.shape, layout.stride)

    def coalesce(self) -> "_MeshLayout":
        """Apply the pycute ``coalesce`` function to this layout.

        Returns:
            A ``_MeshLayout`` built from the ``shape`` and ``stride`` of the
            pycute result.
        """
        # Inside a method body the bare name resolves to the module-level
        # pycute import, not to this method, so this call does not recurse.
        layout = coalesce(self)
        return _MeshLayout(layout.shape, layout.stride)

    def composition(self, layout: "_MeshLayout") -> "_MeshLayout":
        """Apply the pycute ``composition`` function to ``self`` and ``layout``.

        Args:
            layout: passed as the second argument to pycute ``composition``.

        Returns:
            A ``_MeshLayout`` built from the ``shape`` and ``stride`` of the
            pycute result.
        """
        result = composition(self, layout)
        return _MeshLayout(result.shape, result.stride)

    def complement(self, world_size: int) -> "_MeshLayout":
        """Apply the pycute ``complement`` function to this layout.

        Args:
            world_size: passed as the second argument to pycute
                ``complement``.  NOTE(review): presumably the total number of
                ranks the complement should cover -- confirm against pycute.

        Returns:
            A ``_MeshLayout`` built from the ``shape`` and ``stride`` of the
            pycute result.
        """
        layout = complement(self, world_size)
        return _MeshLayout(layout.shape, layout.stride)
|