pytorch/torch/distributed/fsdp/_shard_utils.py

# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank will get its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
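

# Illustrative usage sketch (not part of the original file). It assumes a default
# process group has already been initialized (e.g. dist.init_process_group("gloo"))
# and that every rank holds the same full tensor; each rank then keeps only its
# dim-0 chunk as the local shard.
def _example_create_chunk_sharded_tensor() -> None:
    full = torch.arange(16, dtype=torch.float32).reshape(8, 2)
    st = _create_chunk_sharded_tensor(
        full,
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        num_devices_per_node=max(torch.cuda.device_count(), 1),
        pg=distributed_c10d._get_default_group(),
    )
    # The global metadata describes every chunk; local_shards() holds at most one
    # shard, namely this rank's chunk of `full` (empty on ranks past the last chunk).
    if st.local_shards():
        print(st.local_shards()[0].tensor)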


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank will get its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.detach().clone()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )
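

# Illustrative usage sketch (not part of the original file). It assumes a 1-D CPU
# device mesh over all ranks of an already-initialized default process group; the
# resulting DTensor is sharded on dim 0, so each rank's local tensor is its chunk.
def _example_create_chunk_dtensor() -> DTensor:
    mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
    full = torch.arange(16, dtype=torch.float32).reshape(8, 2)
    dt = _create_chunk_dtensor(full, rank=dist.get_rank(), device_mesh=mesh)
    # dt.to_local() is this rank's dim-0 chunk of `full`.
    return dt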


def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert root_mesh == tensor.device_mesh, (
        "The device mesh of a tensor should be a root mesh."
    )

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
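

# Illustrative usage sketch (not part of the original file), building on the
# hypothetical _example_create_chunk_dtensor above: redistributing the sharded
# dimension back to Replicate() recovers the full tensor on every rank.
def _example_all_gather_dtensor() -> torch.Tensor:
    dt = _example_create_chunk_dtensor()
    full = _all_gather_dtensor(dt, dt.device_mesh)
    # `full` now equals the original unsharded tensor on every rank.
    return full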