# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"
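

# Illustrative sketch (not part of the upstream module): how the helper maps a
# global rank to a remote-device string, assuming 8 devices per node. The
# function name below is hypothetical and exists only for demonstration.
def _example_remote_device_str() -> None:
    # CPU placements carry no device index.
    assert _get_remote_device_str(10, "cpu", 8) == "rank:10/cpu"
    # Accelerator placements use rank % num_devices_per_node as the local index.
    assert _get_remote_device_str(10, "cuda", 8) == "rank:10/cuda:2"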


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
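

# Illustrative usage sketch (not part of the upstream module): builds a
# ShardedTensor from a row-chunked tensor, assuming the default process group
# has already been initialized (e.g. via dist.init_process_group). The function
# name is hypothetical and is never called from this file.
def _example_create_chunk_sharded_tensor() -> ShardedTensor:
    pg = distributed_c10d._get_default_group()
    rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    full_tensor = torch.arange(16, dtype=torch.float32).reshape(8, 2)
    # Each rank keeps at most one chunk of ceil(8 / world_size) rows.
    return _create_chunk_sharded_tensor(
        full_tensor,
        rank=rank,
        world_size=world_size,
        num_devices_per_node=max(torch.cuda.device_count(), 1),
        pg=pg,
    )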


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.detach().clone()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )
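

# Illustrative usage sketch (not part of the upstream module): shards a tensor
# row-wise onto a 1-D device mesh, mirroring FSDP-style [Shard(0)] placements.
# Assumes an initialized process group; the function name and the "cpu" mesh
# are hypothetical choices for demonstration only.
def _example_create_chunk_dtensor() -> DTensor:
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    # The same full tensor exists on every rank, so the initial Replicate()
    # placement in _create_chunk_dtensor is genuinely satisfied.
    full_tensor = torch.arange(16, dtype=torch.float32).reshape(8, 2)
    return _create_chunk_dtensor(full_tensor, rank=dist.get_rank(), device_mesh=mesh)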


def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert root_mesh == tensor.device_mesh, (
        "The device mesh of a tensor should be a root mesh."
    )

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
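

# Illustrative usage sketch (not part of the upstream module): reassembles a
# row-sharded DTensor into a plain local tensor on every rank. Assumes `dt` was
# produced by _create_chunk_dtensor above; the function name is hypothetical.
def _example_all_gather_dtensor(dt: DTensor) -> torch.Tensor:
    # Passing the DTensor's own mesh satisfies the root-mesh assertion.
    return _all_gather_dtensor(dt, root_mesh=dt.device_mesh)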