# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Tuple

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    TensorWriteData,
    WriteItem,
    WriteItemType,
)


aten = (
    torch.ops.aten
)  # pyre-ignore[5]: Globally accessible variable `aten` has no type specified.


class LocalShardsWrapper(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
"""
|
|
A wrapper class to hold local shards of a DTensor.
|
|
This class is used largely for checkpointing purposes and implicity subtypes
|
|
the _Checkpointable protocol.
|
|
"""

    __slots__ = ["_local_shards", "_storage_meta"]
    _local_shards: List[torch.Tensor]
    _storage_meta: TensorStorageMetadata

    @staticmethod
    def __new__(
        cls, local_shards: List[torch.Tensor], local_offsets: List[Tuple[int, ...]]
    ) -> "LocalShardsWrapper":
        assert len(local_shards) > 0
        assert len(local_shards) == len(local_offsets)
        assert all(
            tensor.device == local_shards[0].device for tensor in local_shards[1:]
        )

        # We calculate the total tensor size by "concat" on the second tensor dimension.
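        # For example, two column-wise shards of shapes (4, 2) and (4, 3) at
        # offsets (0, 0) and (0, 2) yield an overall wrapper shape of (4, 5).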
        cat_tensor_shape = list(local_shards[0].size())
        if len(local_shards) > 1:  # column-wise sharding
            for shard in local_shards[1:]:
                cat_tensor_shape[1] += shard.size()[1]

        wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
        wrapper_shape = torch.Size(cat_tensor_shape)
        chunks_meta = [
            ChunkStorageMetadata(
                offsets=torch.Size(offset),
                sizes=shard.size(),
            )
            for shard, offset in zip(local_shards, local_offsets)
        ]

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            torch.Size(cat_tensor_shape),
        )
        r._local_shards = local_shards
        r._storage_meta = TensorStorageMetadata(
            properties=wrapper_properties,
            size=wrapper_shape,
            chunks=chunks_meta,
        )

        return r

    # necessary for ops dispatching from this subclass to its local shards
    @classmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
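
        # Map each supported ATen / functional-collective op to its handler below;
        # any op not listed here falls through to the NotImplementedError branch.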
        dispatcher = {
            torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor,
            torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor,
            aten._to_copy.default: cls.handle_to_copy,
            aten.view.default: cls.handle_view,
            aten.equal.default: cls.handle_equal,
            aten.detach.default: cls.handle_detach,
            aten.clone.default: cls.handle_clone,
        }

        if func in dispatcher:
            return dispatcher[func](
                args, kwargs
            )  # pyre-ignore [29] - `Variable[_VT]` is not a function.
        else:
            raise NotImplementedError(
                f"{func} is not supported for LocalShardsWrapper!"
            )

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_all_gather_into_tensor(args, kwargs):
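        # Flatten every local shard and re-concatenate them into a single 2-D
        # tensor (the column count is taken from the first shard) before issuing
        # the functional all-gather on the combined tensor.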
        dim = args[0].local_sizes()[0][1]
        cat_tensor = torch.cat(
            [t.view(-1) for t in args[0].local_shards()], dim=0
        ).view(-1, dim)
        return torch.ops._c10d_functional.all_gather_into_tensor.default(
            cat_tensor, *args[1:], **kwargs
        )

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_wait_tensor(args, kwargs):
        return torch.ops._c10d_functional.wait_tensor(args[0])

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_to_copy(args, kwargs):
        res_shards_list = [
            aten._to_copy.default(shard, *args[1:], **kwargs)
            for shard in args[0].local_shards()
        ]
        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_view(args, kwargs):
        # TODO: do we need to change the shape of associated offsets?
        res_shards_list = [
            aten.view.default(shard, args[1], **kwargs)
            for shard in args[0].local_shards()
        ]
        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_equal(args, kwargs):
"""
|
|
LocalShardsWrapper equal impl also checks for equality of storage metadata
|
|
and the order of shards
|
|
"""
        a, b = args[0], args[1]
        if len(a.local_shards()) != len(b.local_shards()):
            return False
        if not all(
            aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards())
        ):
            return False
        if not a.storage_metadata() == b.storage_metadata():
            return False
        return True

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_detach(args, kwargs):
        self_ls = args[0]
        detached_local_shards = [
            aten.detach.default(shard) for shard in self_ls.local_shards()
        ]
        self_ls._local_shards = detached_local_shards
        self_ls._storage_meta.properties.requires_grad = False
        return self_ls

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_clone(args, kwargs):
        self_ls = args[0]
        desired_memory_format = kwargs.get("memory_format", None)
        if desired_memory_format and desired_memory_format != torch.preserve_format:
            raise NotImplementedError(
                f"{desired_memory_format} is not supported for LocalShardsWrapper!"
            )
        cloned_local_shards = [
            shard.clone(memory_format=desired_memory_format)
            for shard in self_ls._local_shards
        ]
        return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())

    @property
    def device(self) -> torch._C.device:  # type: ignore[override]
        return self._local_shards[0].device

    @property
    def is_meta(self) -> bool:  # type: ignore[override]
        return self._local_shards[0].is_meta

    # pyre-ignore[14]
    def is_pinned(self) -> bool:  # type: ignore[override]
        return self._storage_meta.properties.pin_memory

    # pyre-ignore[14]
    def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper":
        self._storage_meta.properties.requires_grad = requires_grad
        for shard in self._local_shards:
            shard.requires_grad_(requires_grad)
        return self

    def local_shards(self) -> List[torch.Tensor]:
        """
        Returns a list of :class:`torch.Tensor` corresponding to the
        local shards for this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return self._local_shards

    def local_sizes(self) -> List[torch.Size]:
        """
        Returns a list of :class:`torch.Size` corresponding to the
        local sizes for the shards on this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return [chunk.sizes for chunk in self._storage_meta.chunks]

    def local_offsets(self) -> List[torch.Size]:
        """
        Returns a list of :class:`torch.Size` corresponding to the
        local offsets for the shards on this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return [chunk.offsets for chunk in self._storage_meta.chunks]

    @property
    def local_chunks(self) -> List[ChunkStorageMetadata]:
        """
        Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the
        metadata for each tensor shard.
        """
        return self._storage_meta.chunks

    def storage_metadata(self) -> TensorStorageMetadata:
        """
        Returns a :class:`TensorStorageMetadata` object corresponding to the
        metadata for the local tensor on the current rank.
        """
        return self._storage_meta

    def __create_write_items__(
        self, fqn: str, object: Any
    ) -> List[WriteItem]:  # pyre-ignore[2]
        """
        For compatibility with DCP, we support creation of WriteItems
        such that they can be saved properly.
        """
        return [
            WriteItem(
                index=MetadataIndex(fqn, chunks.offsets),
                type=WriteItemType.SHARD,
                tensor_data=TensorWriteData(
                    chunk=ChunkStorageMetadata(
                        offsets=chunks.offsets,
                        sizes=chunks.sizes,
                    ),
                    properties=self._storage_meta.properties,
                    size=object.size(),
                ),
            )
            for tensor, chunks in zip(self.local_shards(), self.local_chunks)
        ]

    def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
        """
        For compatibility with DCP, we support creation of chunk lists
        such that they can be saved properly.
        """
        return self._storage_meta.chunks

    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
        """
        For compatibility with DCP, we support finding a shard based on its index.
        Returns a :class:`torch.Tensor` shard based on the given :class:`MetadataIndex`.
        """
        # Fast lookup path
        if index.index is not None:
            if (
                len(self._local_shards) > index.index
                and self._storage_meta.chunks[index.index].offsets == index.offset
            ):
                return self._local_shards[index.index]
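
        # Slow path: linearly scan the chunk metadata for a matching offset.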
        if index.offset is not None:
            for shard, chunk in zip(self._local_shards, self._storage_meta.chunks):
                if chunk.offsets == index.offset:
                    return shard

        raise ValueError(
            f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
        )

    def _get_tensor_size_bytes(self) -> int:
        object_size = 0
        for shard in self.local_shards():
            object_size += shard.nelement() * shard.element_size()
        return object_size

    # pyre-fixme[3]: Return type must be annotated.
    def __hash__(self):
        return id(self)

    # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently.
    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self) -> str:  # type: ignore[override]
        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"

    def __str__(self) -> str:
        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"