Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
commit 995df34b19 (parent 4e160d5fd9), committed by PyTorch MergeBot
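Most of the hunks below are mechanical re-layouts from the formatter switch. The most frequent pattern is how long asserts are wrapped: under black the condition was parenthesized and the message trailed the closing parenthesis, while ruff format keeps the condition on the assert line and parenthesizes the message instead. A minimal sketch of that pattern, lifted from the reduce_scatter_tensor hunk below (indentation is approximate, since the extraction dropped it):

# black-era layout
assert (
    self.size(scatter_dim) % group_size == 0
), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"

# ruff format layout
assert self.size(scatter_dim) % group_size == 0, (
    f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
)

Both forms raise the same AssertionError with the same message; only the wrapping changes.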
@@ -59,7 +59,6 @@ USE_BLACK_FILELIST = re.compile(
 # torch/[a-c]*/**
 "torch/[a-c]*/**",
 # torch/d*/**
-"torch/d*/**",
 # torch/[e-n]*/**
 "torch/[e-n]*/**",
 # torch/optim/**
@@ -36,11 +36,9 @@ _M = TypeVar("_M", nn.Module, list[nn.Module])


 class _ContractFn(Protocol, Generic[_P, _T, _TState]):
-def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
-...
+def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...

-def state(self, module: nn.Module) -> _TState:
-...
+def state(self, module: nn.Module) -> _TState: ...


 def contract(
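The Protocol hunk above also shows how ruff format folds a trivial `...` stub body onto the signature line. A self-contained sketch of the same style (hypothetical protocol, not taken from this diff):

from typing import Protocol


class _Reader(Protocol):  # hypothetical example
    # before the migration the "..." body sat on its own indented line;
    # ruff format keeps it on the signature line
    def read(self, n: int) -> bytes: ...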
@@ -92,7 +90,7 @@ def contract(
 # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
 @wraps(state_cls) # type: ignore[arg-type]
 def inner(
-func: Callable[Concatenate[_M, _P], _M]
+func: Callable[Concatenate[_M, _P], _M],
 ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
 @wraps(func)
 def wrapper(
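The `func: Callable[...]` change above is ruff's magic trailing comma at work: when a parameter list is already split across lines, the last parameter gains a trailing comma so the split is preserved on reformatting. A small sketch with hypothetical names:

from collections.abc import Callable


def inner(
    func: Callable[[int], int],  # trailing comma added by ruff format
) -> Callable[[int], int]:
    return func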
@@ -232,9 +230,7 @@ def contract(
 return module.__dict__.setdefault( # type: ignore[call-overload]
 STATE_KEY,
 {}, # TODO(@yhcharles): this is a temporary fix, need a better way
-).get(
-func
-) # type: ignore[call-overload]
+).get(func) # type: ignore[call-overload]

 wrapper.state = get_state # type: ignore[attr-defined]

@@ -274,9 +274,9 @@ def reduce_scatter_tensor(
 group_name = _resolve_group_name(group, tag)
 group_size = c10d._get_group_size_by_name(group_name)

-assert (
-self.size(scatter_dim) % group_size == 0
-), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
+assert self.size(scatter_dim) % group_size == 0, (
+f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
+)
 if scatter_dim != 0:
 tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
 self = torch.cat(tensor_list)
@@ -313,9 +313,9 @@ def reduce_scatter_tensor_autograd(
 group_name = _resolve_group_name(group, tag)
 group_size = c10d._get_group_size_by_name(group_name)

-assert (
-self.size(scatter_dim) % group_size == 0
-), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
+assert self.size(scatter_dim) % group_size == 0, (
+f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
+)
 if scatter_dim != 0:
 tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
 self = torch.cat(tensor_list)
@@ -414,9 +414,9 @@ def reduce_scatter_tensor_coalesced(

 assert len(scatter_dim) == len(inputs)
 for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
-assert (
-tensor.size(dim) % group_size == 0
-), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
+assert tensor.size(dim) % group_size == 0, (
+f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
+)
 if dim != 0:
 tensor_list = torch.chunk(tensor, group_size, dim=dim)
 inputs[idx] = torch.cat(tensor_list)
@@ -574,6 +574,7 @@ class AsyncCollectiveTensor(torch.Tensor):
 tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
 return _maybe_wrap_tensor(tensor)
 """

 elem: torch.Tensor
 completed: bool

@@ -726,9 +727,9 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int
 group_size = len(rankset)
 tag = tag or c10d._get_group_tag(group)
 elif isinstance(group, DeviceMesh):
-assert (
-group.ndim == 1
-), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
+assert group.ndim == 1, (
+"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
+)
 # TODO: it should run collective in the whole mesh instead of dim 0
 tag, rankset, _ = group._dim_group_infos[0]
 group_size = len(rankset)
@@ -763,9 +764,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
 elif isinstance(group, str):
 return group
 elif isinstance(group, DeviceMesh):
-assert (
-group.ndim == 1
-), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
+assert group.ndim == 1, (
+"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
+)
 return group._dim_group_infos[0][2]
 elif isinstance(group, tuple):
 if (
@@ -837,11 +838,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
 req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
 return y


 @torch.compile(fullgraph=True)
 def all_reduce_wait_compiled(y):
 torch.ops.c10d_functional.wait_tensor(y)
 return y * y


 x = torch.ones(1280, 1280, device="cuda") + self.rank
 # the context manager ensures that `wait_tensor(y)` will wait on the correct work object
 with allow_inflight_collective_as_graph_input_ctx():
@@ -1057,9 +1060,9 @@ def all_gather_tensor_inplace(
 tag: str = "",
 gather_dim: int = 0,
 ):
-assert (
-not async_op
-), "Can't remap async version of inplace op to functional collective"
+assert not async_op, (
+"Can't remap async version of inplace op to functional collective"
+)

 group = group or dist.group.WORLD
 assert group is not None
@@ -1076,9 +1079,9 @@ def reduce_scatter_tensor_inplace(
 scatter_dim: int = 0,
 tag: str = "",
 ):
-assert (
-not async_op
-), "Can't remap async version of inplace op to functional collective"
+assert not async_op, (
+"Can't remap async version of inplace op to functional collective"
+)

 group = group or dist.group.WORLD
 assert group is not None
@@ -1105,9 +1108,9 @@ def all_reduce_inplace(
 async_op: bool = False,
 tag: str = "",
 ):
-assert (
-not async_op
-), "Can't remap async version of inplace op to functional collective"
+assert not async_op, (
+"Can't remap async version of inplace op to functional collective"
+)

 group = group or dist.group.WORLD
 assert group is not None
@@ -1124,9 +1127,9 @@ def all_to_all_inplace(
 async_op=False,
 tag: str = "",
 ):
-assert (
-not async_op
-), "Can't remap async version of inplace op to functional collective"
+assert not async_op, (
+"Can't remap async version of inplace op to functional collective"
+)

 group = group or dist.group.WORLD
 assert group is not None
@@ -1149,12 +1152,12 @@ def all_gather_inplace(
 async_op=False,
 tag: str = "",
 ):
-assert (
-not async_op
-), "Can't remap async version of inplace op to functional collective"
-assert all(
-t.size(0) == tensor.size(0) for t in tensor_list
-), "Remapping variable size all_gather is not yet supported"
+assert not async_op, (
+"Can't remap async version of inplace op to functional collective"
+)
+assert all(t.size(0) == tensor.size(0) for t in tensor_list), (
+"Remapping variable size all_gather is not yet supported"
+)

 group = group or dist.group.WORLD
 assert group is not None
@@ -592,7 +592,9 @@ class ShardedTensor(ShardedTensorBase):
 assert (
 isinstance(device, torch.device)
 and device.index == torch.cuda.current_device()
-), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""
+), (
+"""Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""
+)

 current_device = torch.device(torch.cuda.current_device())
 # returns a copy of ShardedTensor on CUDA current device
@@ -831,7 +833,9 @@ class ShardedTensor(ShardedTensorBase):
 "rank:1/cuda:1",
 ],
 )
->>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
+>>> st = ShardedTensor._init_from_local_tensor(
+...     local_tensor, sharding_spec, [2, 4]
+... )
 >>> st
 ShardedTensor(
 ShardedTensorMetadata(
@@ -219,9 +219,7 @@ def reshard_local_shard(
 output_tensor_size = list(st_size)
 output_tensor_size[current_sharding_dim] = sharded_dim_size
 output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
-output_tensor_list[
-placement.rank()
-] = torch.empty( # type: ignore[union-attr, index]
+output_tensor_list[placement.rank()] = torch.empty( # type: ignore[union-attr, index]
 output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
 )
 indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type]
@@ -16,6 +16,6 @@ with warnings.catch_warnings():
 stacklevel=2,
 )

-sys.modules[
-"torch.distributed._sharded_tensor"
-] = torch.distributed._shard.sharded_tensor
+sys.modules["torch.distributed._sharded_tensor"] = (
+torch.distributed._shard.sharded_tensor
+)
@@ -67,7 +67,7 @@ def _all_gather_sharded_tensor(


 class CompanionMismatch(Exception):
-...
+pass


 def _iterate_state_dict(
@@ -409,9 +409,9 @@ def _create_cpu_state_dict(

 def unpin_memory(t):
 succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))
-assert (
-succ == 0
-), f"Unpinning shared memory failed with error-code: {succ}"
+assert succ == 0, (
+f"Unpinning shared memory failed with error-code: {succ}"
+)

 weakref.finalize(t, unpin_memory, t)
 succ = int(
@@ -421,9 +421,9 @@ def _create_cpu_state_dict(
 1, # lines up with 'cudaHostRegisterPortable'
 )
 )
-assert (
-succ == 0
-), f"Pinning shared memory failed with error-code: {succ}"
+assert succ == 0, (
+f"Pinning shared memory failed with error-code: {succ}"
+)
 return t
 elif pin_memory:
 return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
@@ -1525,8 +1525,7 @@ if TYPE_CHECKING:
 @overload
 def empty(
 *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
-) -> torch.Tensor:
-...
+) -> torch.Tensor: ...


 @overload
@@ -1535,8 +1534,7 @@ def empty(
 *,
 dtype: Optional[_dtype] = None,
 device: Optional[_device] = None,
-) -> torch.Tensor:
-...
+) -> torch.Tensor: ...


 def empty( # type: ignore[misc]
@@ -6,6 +6,7 @@ we keep the old import path starts with `_tensor` for
 backward compatibility. We will remove this folder once
 we resolve all the BC issues.
 """

 import sys
 from importlib import import_module

@@ -153,7 +153,7 @@ class FSDPMemTracker(MemTracker):
 loss.backward()
 optimizer.step()
 fmt.display_snapshot("peak")
-fmt.display_modulewise_snapshots(depth = 3, units = "MB")
+fmt.display_modulewise_snapshots(depth=3, units="MB")

 """

@@ -379,7 +379,7 @@ class MemTracker(TorchDispatchMode):
 optimizer.step()
 optimizer.zero_grad()
 mt.display_snapshot("peak")
-mt.display_modulewise_snapshots(depth = 3, units = "MiB")
+mt.display_modulewise_snapshots(depth=3, units="MiB")

 Known Limitations:
 - The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``.
@@ -42,6 +42,7 @@ class ModTracker:
 def my_linear(m1, m2, bias):
 print(f"Current modules: {tracker.parents}")
 return torch.mm(m1, m2.t()) + bias

 torch.nn.functional.linear = my_linear

 mod(torch.rand(2, 2))
@@ -255,9 +255,9 @@ class RuntimeEstimator(TorchDispatchMode):
 Tuple[Any, float]: A tuple containing the result of the function and
 the mean operation time in milliseconds.
 """
-assert isinstance(
-cls.fake_mode, FakeTensorMode
-), "Initialize/Assign FakeTensorMode before using this function"
+assert isinstance(cls.fake_mode, FakeTensorMode), (
+"Initialize/Assign FakeTensorMode before using this function"
+)
 mean_op_time = 0.0
 if func._overloadpacket not in _VIEW_OPS:
 try:
@@ -289,9 +289,9 @@ class RuntimeEstimator(TorchDispatchMode):
 Tuple[Any, float]: A tuple containing the result of the function and
 the mean operation time in milliseconds.
 """
-assert (
-torch.cuda.is_available()
-), "Roofline estimation needs to access CUDA capabilities to make estimations"
+assert torch.cuda.is_available(), (
+"Roofline estimation needs to access CUDA capabilities to make estimations"
+)

 def get_num_bytes(t: torch.Tensor) -> int:
 """
@@ -324,9 +324,9 @@ class RuntimeEstimator(TorchDispatchMode):
 float: The estimated compute time in nanoseconds.
 """
 if func_packet in flop_registry:
-assert (
-len(out_dtypes) == 1
-), f"Only support single out dtype got {out_dtypes} for {func_packet}"
+assert len(out_dtypes) == 1, (
+f"Only support single out dtype got {out_dtypes} for {func_packet}"
+)
 dtype = out_dtypes.pop()
 # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s
 peak_gpu_flops = get_device_tflops(dtype) * 1e15
@@ -487,9 +487,9 @@ class RuntimeEstimator(TorchDispatchMode):

 def __enter__(self) -> Self:
 fake_mode = active_fake_mode()
-assert isinstance(
-fake_mode, FakeTensorMode
-), "No FakeTensorMode found, designed to used under FakeTensorMode"
+assert isinstance(fake_mode, FakeTensorMode), (
+"No FakeTensorMode found, designed to used under FakeTensorMode"
+)
 RuntimeEstimator.fake_mode = fake_mode
 self.total_runtime = 0.0
 self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0))
@@ -245,7 +245,7 @@ class SACEstimator(TorchDispatchMode):
 with FakeTensorMode():
 module = ...
 inp = ...
-with sac_estimator('operator-level-cost-model'):
+with sac_estimator("operator-level-cost-model"):
 output = module(inp)
 sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True)
 """
@@ -442,9 +442,9 @@ class SACEstimator(TorchDispatchMode):
 out_storages_cpu.update(_get_untyped_storages(o))

 # Check if there's more than 1 CUDA device
-assert (
-len(cuda_devices) <= 1
-), f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
+assert len(cuda_devices) <= 1, (
+f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
+)

 # 2. Get the memory consumed by output
 nbytes_cuda = sum(
@@ -484,9 +484,9 @@ class SACEstimator(TorchDispatchMode):
 if acm_stats := self._sac_mod_metadata.get(mod_fqn, None):
 acm_stats.sac_metadata.append(acm)
 else:
-assert (
-mod_fqn == "Global"
-), f"Module {mod_fqn} not found in AC Mod Stats"
+assert mod_fqn == "Global", (
+f"Module {mod_fqn} not found in AC Mod Stats"
+)
 self._sac_metadata.append(acm)

 return out
@@ -979,9 +979,9 @@ class SACEstimator(TorchDispatchMode):

 def __enter__(self) -> Self: # type: ignore[no-untyped-def]
 fake_mode = active_fake_mode()
-assert isinstance(
-fake_mode, FakeTensorMode
-), "SAC Estimator should be called in FakeTensorMode"
+assert isinstance(fake_mode, FakeTensorMode), (
+"SAC Estimator should be called in FakeTensorMode"
+)
 RuntimeEstimator.fake_mode = fake_mode
 self._mod_tracker.register_user_hooks(
 pre_fw_hook=self._pre_fw_hook,
@@ -38,9 +38,9 @@ def _perform_local_step(
 """
 overlap_info = zero._overlap_info
 bucket_index = bucket.index()
-assert (
-len(zero.optim.param_groups) == 1
-), "Overlapping DDP with ZeRO only supports a single parameter group"
+assert len(zero.optim.param_groups) == 1, (
+"Overlapping DDP with ZeRO only supports a single parameter group"
+)

 # Construct the `gradients` input for the local optimizer step, which
 # expects `None` in a list position to indicate that the corresponding
@@ -49,9 +49,9 @@ def _perform_local_step(
 gradients: list[Optional[torch.Tensor]] = [
 _NO_PARAM_UPDATE for _ in range(num_local_optim_params)
 ]
-assert (
-bucket_index in overlap_info.offsets
-), f"Bucket index {bucket_index} was not assigned to rank {rank}"
+assert bucket_index in overlap_info.offsets, (
+f"Bucket index {bucket_index} was not assigned to rank {rank}"
+)
 gradients_offset = overlap_info.offsets[bucket_index]
 bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
 bucket_offset = bucket_assignment.offset
@@ -77,13 +77,13 @@ def _broadcast_bucket(
 :class:`ZeroRedundancyOptimizer` instance.
 """
 overlap_info = zero._overlap_info
-assert (
-len(overlap_info.assigned_ranks_per_bucket) > bucket_index
-), "`assigned_ranks_per_bucket` is not fully constructed"
+assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
+"`assigned_ranks_per_bucket` is not fully constructed"
+)
 # Sort to ensure the same ordering across ranks
 assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
 assert len(assigned_ranks) > 0, (
-f"Bucket {bucket_index} should be " "assigned to at least one rank"
+f"Bucket {bucket_index} should be assigned to at least one rank"
 )
 for assigned_rank in assigned_ranks:
 bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
@@ -273,9 +273,9 @@ def hook_with_zero_step(
 rank = zero.global_rank

 assert overlap_info.status == _OverlapStatus.INITIALIZED
-assert (
-len(overlap_info.assigned_ranks_per_bucket) > bucket_index
-), "`assigned_ranks_per_bucket` is not fully constructed"
+assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
+"`assigned_ranks_per_bucket` is not fully constructed"
+)
 assigned_to_bucket = (
 rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
 )
@@ -288,9 +288,9 @@ def hook_with_zero_step(
 # Check that buckets are indexed incrementally starting from 0 in the
 # order of their autograd hooks firing
 if len(overlap_info.bucket_indices_seen) > 0:
-assert (
-overlap_info.bucket_indices_seen[-1] == bucket_index - 1
-), "Bucket indices are not in incremental order"
+assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, (
+"Bucket indices are not in incremental order"
+)
 else:
 assert bucket_index == 0, "Bucket indices do not start from 0"
 overlap_info.bucket_indices_seen.append(bucket_index)
@@ -129,7 +129,7 @@ def bf16_compress_hook(


 def fp16_compress_wrapper(
-hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
+hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
 """
 Cast input tensor to ``torch.float16``, cast result of hook back to input dtype.
@@ -167,7 +167,7 @@ def fp16_compress_wrapper(


 def bf16_compress_wrapper(
-hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
+hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
 """
 Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
@@ -223,8 +223,7 @@ class Join:
 self._rank = dist.get_rank(self._process_group)
 self._device = device

-def __enter__(self):
-...
+def __enter__(self): ...

 def __exit__(
 self,
@@ -52,7 +52,10 @@ def average_parameters(


 def get_params_to_average(
-params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]]
+params: Union[
+Iterable[torch.nn.Parameter],
+Iterable[dict[str, torch.nn.Parameter]],
+],
 ):
 """
 Return a list of parameters that need to average.
@@ -550,9 +550,7 @@ def create_default_global_save_plan(
 new_item = dataclasses.replace(item, index=new_index)
 new_items.append(new_item)

-assert (
-item.tensor_data.chunk is not None
-), f"""
+assert item.tensor_data.chunk is not None, f"""
 Cannot create MD for tensor without bounds.
 FQN: {item.index.fqn}
 """
@@ -414,41 +414,33 @@ class FileSystemBase(ABC):
 @abstractmethod
 def create_stream(
 self, path: Union[str, os.PathLike], mode: str
-) -> Generator[io.IOBase, None, None]:
-...
+) -> Generator[io.IOBase, None, None]: ...

 @abstractmethod
 def concat_path(
 self, path: Union[str, os.PathLike], suffix: str
-) -> Union[str, os.PathLike]:
-...
+) -> Union[str, os.PathLike]: ...

 @abstractmethod
 def rename(
 self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
-) -> None:
-...
+) -> None: ...

 @abstractmethod
-def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
-...
+def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: ...

 @abstractmethod
-def mkdir(self, path: Union[str, os.PathLike]) -> None:
-...
+def mkdir(self, path: Union[str, os.PathLike]) -> None: ...

 @classmethod
 @abstractmethod
-def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
-...
+def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: ...

 @abstractmethod
-def exists(self, path: Union[str, os.PathLike]) -> bool:
-...
+def exists(self, path: Union[str, os.PathLike]) -> bool: ...

 @abstractmethod
-def rm_file(self, path: Union[str, os.PathLike]) -> None:
-...
+def rm_file(self, path: Union[str, os.PathLike]) -> None: ...


 class FileSystem(FileSystemBase):
@@ -512,7 +504,6 @@ class FileSystem(FileSystemBase):


 class _FileSystemWriter(StorageWriter):
-
 """
 Basic implementation of StorageWriter using file IO.

@@ -800,9 +791,9 @@ class FileSystemReader(StorageReader):
 )
 target_tensor = planner.resolve_tensor(req).detach()

-assert (
-target_tensor.size() == tensor.size()
-), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
+assert target_tensor.size() == tensor.size(), (
+f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
+)
 target_tensor.copy_(tensor)
 planner.commit_tensor(req, target_tensor)

@@ -135,12 +135,12 @@ def _get_state_dict_2d_layout(
 for key, value in state_dict.items():
 specs[key] = (None, value.size())
 if _is_nested_tensor(value):
-assert (
-len(value.local_shards()) == 1
-), "Cannot handle ST with multiple shards"
-assert isinstance(
-value, ShardedTensor
-), "Can only handle nested ShardedTensor"
+assert len(value.local_shards()) == 1, (
+"Cannot handle ST with multiple shards"
+)
+assert isinstance(value, ShardedTensor), (
+"Can only handle nested ShardedTensor"
+)
 shard = value.local_shards()[0]
 specs[key] = (
 shard.metadata.shard_offsets,
@@ -151,7 +151,7 @@ class SavePlanner(abc.ABC):
 >>> storage_meta: Optional[StorageMeta],
 >>> is_coordinator: bool,
 >>> ) -> None:
 >>> # prefix all keys with `foo_``
 >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

 Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted
@@ -175,8 +175,8 @@ class SavePlanner(abc.ABC):
 >>> from itertools import zip_longest
 >>> from dataclasses import replace
 >>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
 >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
 >>> # This sample doesn't handle ShardedTensors
 >>> def create_global_plan(self, all_plans):
 >>> iters = [iter(all_plans[0].items)] * len(all_plans)
 >>> items_per_rank = [
@@ -347,7 +347,7 @@ class LoadPlanner:
 >>> self.is_coordinator = is_coordinator
 >>>
 >>> def load_bytes(self, read_item, value):
 >>> # Remove the "foo_" prefix
 >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)


@@ -140,10 +140,12 @@ class StateDictOptions:
 @dataclass
 class _StateDictInfo(StateDictOptions):
 fqn_param_mapping: dict[
-Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
+Union[str, torch.Tensor],
+Union[FQNS_T, torch.Tensor],
 ] = field(default_factory=dict)
 shared_params_mapping: dict[
-Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
+Union[str, torch.Tensor],
+Union[FQNS_T, torch.Tensor],
 ] = field(default_factory=dict)
 submodule_prefixes: set[str] = field(default_factory=set)
 handle_model: bool = True
@@ -1140,7 +1142,9 @@ def get_state_dict(


 >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
->>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
+>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
+...     fsdp_model, fsdp_optim
+... )

 >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
 >>> # the asserts will fail.
@@ -125,7 +125,9 @@ def load(
 >>> my_model = MyModule()
 >>> optimizer = Adagrad(my_model.parameters())
 >>> model_state_dict = my_model.state_dict()
->>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
+>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
+...     "/checkpoint/1"
+... )

 >>> torch.distributed.checkpoint.load_state_dict(
 >>> state_dict=model_state_dict,
@@ -127,7 +127,9 @@ def save(

 >>> state_dict = {"model": my_model}

->>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
+>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
+...     "/checkpoint/1"
+... )
 >>> torch.distributed.checkpoint.save(
 >>> state_dict=state_dict,
 >>> storage_writer=fs_storage_writer,
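Several docstring hunks (load, save, async_save, and the all_gather/all_to_all examples further down) show the same reflow: with docstring code formatting enabled, ruff wraps long doctest lines using `...` continuation prompts. A hedged sketch of the resulting style (hypothetical function name; the path comes from the diff itself):

def save_checkpoint_example():
    """Illustrates the reflowed doctest style used throughout this diff.

    >>> # xdoctest: +SKIP("requires an initialized process group")
    >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
    ...     "/checkpoint/1"
    ... )
    """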
@@ -206,7 +208,9 @@ def async_save(

 >>> state_dict = {"model": my_model}

->>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
+>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
+...     "/checkpoint/1"
+... )
 >>> checkpoint_future = torch.distributed.checkpoint.async_save(
 >>> state_dict=state_dict,
 >>> storage_writer=fs_storage_writer,
@@ -223,7 +227,9 @@ def async_save(
 pg = process_group or _get_default_group()
 assert (
 torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
-), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
+), (
+"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
+)

 storage_writer = cast(
 StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
@@ -32,7 +32,7 @@ R = TypeVar("R")


 def _get_failure_dict(
-results: list[Union[T, WRAPPED_EXCEPTION]]
+results: list[Union[T, WRAPPED_EXCEPTION]],
 ) -> dict[int, WRAPPED_EXCEPTION]:
 return cast(
 dict[int, WRAPPED_EXCEPTION],
@@ -221,8 +221,12 @@ else:
 if cur_rank in mesh_nd:
 res_flattened_mesh = flattened_mesh
 self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined]
-self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh # type: ignore[possibly-undefined]
-self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root) # type: ignore[possibly-undefined]
+self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
+res_flattened_mesh # type: ignore[possibly-undefined]
+)
+self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(
+flatten_dims_in_root
+) # type: ignore[possibly-undefined]

 return res_flattened_mesh

@@ -242,9 +246,9 @@ else:
 root_mesh = self.get_root_mesh(device_mesh)
 child_mesh_dim_names = device_mesh.mesh_dim_names
 if root_mesh and child_mesh_dim_names:
-assert (
-len(child_mesh_dim_names) == 1
-), "The submesh can only be a 1D mesh."
+assert len(child_mesh_dim_names) == 1, (
+"The submesh can only be a 1D mesh."
+)
 child_mesh_dim_name = child_mesh_dim_names[0]
 return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name)
 return None
@@ -763,7 +767,9 @@ else:
 root_mesh, None
 )
 if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
-dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2] # type: ignore[index]
+dim_group_infos = root_to_flatten_mapping[
+mesh_dim # type: ignore[index]
+]._dim_group_infos[0][:2]
 return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos))
 else:
 mesh_dim = (
@@ -905,9 +911,9 @@ else:
 mesh_dim = 0

 mesh_dim_group = not_none(self.get_group(mesh_dim))
-assert isinstance(
-mesh_dim_group, ProcessGroup
-), "We expect ProcessGroup before calling `get_rank`!"
+assert isinstance(mesh_dim_group, ProcessGroup), (
+"We expect ProcessGroup before calling `get_rank`!"
+)
 return not_none(get_rank(mesh_dim_group))

 def get_coordinate(self) -> Optional[list[int]]:
@@ -334,12 +334,12 @@ class Backend(str): # noqa: SLOT000
 # Allow UCC plugin if Pytorch is not built with native support.
 # TODO: remove this exception once UCC plugin is fully deprecated.
 if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()):
-assert not hasattr(
-Backend, name.upper()
-), f"{name.upper()} c10d backend already exist"
-assert (
-name.upper() not in Backend._plugins
-), f"{name.upper()} c10d backend creator function already exist"
+assert not hasattr(Backend, name.upper()), (
+f"{name.upper()} c10d backend already exist"
+)
+assert name.upper() not in Backend._plugins, (
+f"{name.upper()} c10d backend creator function already exist"
+)

 setattr(Backend, name.upper(), name.lower())
 Backend.backend_list.append(name.lower())
@@ -1650,9 +1650,9 @@ def init_process_group(
 if "torch._dynamo" in sys.modules:
 torch._dynamo.trace_rules.clear_lru_cache()

-assert (store is None) or (
-init_method is None
-), "Cannot specify both init_method and store."
+assert (store is None) or (init_method is None), (
+"Cannot specify both init_method and store."
+)

 if store is not None:
 assert world_size > 0, "world_size must be positive if using store"
@@ -1734,7 +1734,10 @@ def init_process_group(
 )
 _update_default_pg(default_pg)

-_world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index]
+_world.pg_group_ranks[GroupMember.WORLD] = { # type: ignore[index]
+i: i
+for i in range(GroupMember.WORLD.size()) # type: ignore[attr-defined]
+}
 _backend = _world.pg_map[not_none(GroupMember.WORLD)][0]
 _default_pg_init_method = init_method

@@ -1959,9 +1962,9 @@ def _new_process_group_helper(
 if not is_nccl_available():
 raise RuntimeError("Distributed package doesn't have NCCL built in")
 if backend_options is not None:
-assert isinstance(
-backend_options, ProcessGroupNCCL.Options
-), "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
+assert isinstance(backend_options, ProcessGroupNCCL.Options), (
+"Expected backend_options argument to be of type ProcessGroupNCCL.Options"
+)
 if backend_options._timeout != timeout:
 warnings.warn(
 "backend_options._timeout was specified, "
@@ -2001,9 +2004,9 @@ def _new_process_group_helper(
 )
 backend_type = ProcessGroup.BackendType.XCCL
 else:
-assert (
-backend_str.upper() in Backend._plugins
-), f"Unknown c10d backend type {backend_str.upper()}"
+assert backend_str.upper() in Backend._plugins, (
+f"Unknown c10d backend type {backend_str.upper()}"
+)

 backend_plugin = Backend._plugins[backend_str.upper()]
 creator_fn = backend_plugin.creator_fn
@@ -2630,8 +2633,10 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
 >>> # xdoctest: +SKIP("no rank")
 >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
 >>> recv_tensor = torch.randn(2, dtype=torch.float32)
->>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size)
->>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size)
+>>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
+>>> recv_op = dist.P2POp(
+...     dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
+... )
 >>> reqs = batch_isend_irecv([send_op, recv_op])
 >>> for req in reqs:
 >>> req.wait()
@@ -2758,7 +2763,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
 >>> # xdoctest: +SKIP("no rank")
 >>> # All tensors below are of torch.int64 type.
 >>> # We have 2 process groups, 2 ranks.
->>> device = torch.device(f'cuda:{rank}')
+>>> device = torch.device(f"cuda:{rank}")
 >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
 >>> tensor
 tensor([1, 2], device='cuda:0') # Rank 0
@@ -2770,7 +2775,9 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):

 >>> # All tensors below are of torch.cfloat type.
 >>> # We have 2 process groups, 2 ranks.
->>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
+>>> tensor = torch.tensor(
+...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
+... ) + 2 * rank * (1 + 1j)
 >>> tensor
 tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
@@ -3380,9 +3387,9 @@ def recv_object_list(
 )

 rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
-assert (
-rank_sizes == rank_objects
-), "Mismatch in return ranks for object sizes and objects."
+assert rank_sizes == rank_objects, (
+"Mismatch in return ranks for object sizes and objects."
+)
 # Deserialize objects using their stored sizes.
 offset = 0
 for i, obj_size in enumerate(object_sizes_tensor):
@@ -3673,8 +3680,10 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
 >>> # xdoctest: +SKIP("need process group init")
 >>> # All tensors below are of torch.int64 dtype.
 >>> # We have 2 process groups, 2 ranks.
->>> device = torch.device(f'cuda:{rank}')
->>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)]
+>>> device = torch.device(f"cuda:{rank}")
+>>> tensor_list = [
+...     torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
+... ]
 >>> tensor_list
 [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
 [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
@@ -3689,11 +3698,15 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):

 >>> # All tensors below are of torch.cfloat dtype.
 >>> # We have 2 process groups, 2 ranks.
->>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)]
+>>> tensor_list = [
+...     torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
+... ]
 >>> tensor_list
 [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
 [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
->>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
+>>> tensor = torch.tensor(
+...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
+... ) + 2 * rank * (1 + 1j)
 >>> tensor
 tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
@@ -3769,7 +3782,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
 >>> # xdoctest: +SKIP("need process group init")
 >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
 >>> # We have two ranks.
->>> device = torch.device(f'cuda:{rank}')
+>>> device = torch.device(f"cuda:{rank}")
 >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
 >>> tensor_in
 tensor([1, 2], device='cuda:0') # Rank 0
@@ -3969,8 +3982,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
 )
 elif gather_list:
 raise ValueError(
-"Argument ``gather_list`` must NOT be specified "
-"on non-destination ranks."
+"Argument ``gather_list`` must NOT be specified on non-destination ranks."
 )


@@ -4141,8 +4153,7 @@ def scatter(
 else:
 if scatter_list:
 raise ValueError(
-"Argument ``scatter_list`` must NOT be specified "
-"on non-source ranks."
+"Argument ``scatter_list`` must NOT be specified on non-source ranks."
 )
 input_tensors = []
 output_tensors = [tensor]
@ -4225,7 +4236,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
|
|||||||
>>> # xdoctest: +SKIP("need process group init")
|
>>> # xdoctest: +SKIP("need process group init")
|
||||||
>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
|
>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
|
||||||
>>> # We have two ranks.
|
>>> # We have two ranks.
|
||||||
>>> device = torch.device(f'cuda:{rank}')
|
>>> device = torch.device(f"cuda:{rank}")
|
||||||
>>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
|
>>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
|
||||||
>>> # Input in concatenation form
|
>>> # Input in concatenation form
|
||||||
>>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
|
>>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
|
||||||
@ -4381,7 +4392,7 @@ def all_to_all_single(
|
|||||||
|
|
||||||
>>> # Essentially, it is similar to following operation:
|
>>> # Essentially, it is similar to following operation:
|
||||||
>>> scatter_list = list(input.chunk(world_size))
|
>>> scatter_list = list(input.chunk(world_size))
|
||||||
>>> gather_list = list(output.chunk(world_size))
|
>>> gather_list = list(output.chunk(world_size))
|
||||||
>>> for i in range(world_size):
|
>>> for i in range(world_size):
|
||||||
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
|
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
|
||||||
|
|
||||||
@ -4411,7 +4422,9 @@ def all_to_all_single(
|
|||||||
|
|
||||||
|
|
||||||
>>> # Another example with tensors of torch.cfloat type.
|
>>> # Another example with tensors of torch.cfloat type.
|
||||||
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
|
>>> input = torch.tensor(
|
||||||
|
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
|
||||||
|
... ) + 4 * rank * (1 + 1j)
|
||||||
>>> input
|
>>> input
|
||||||
tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0
|
tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0
|
||||||
tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1
|
tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1
|
||||||
@ -4510,7 +4523,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
|
|||||||
|
|
||||||
>>> # Essentially, it is similar to following operation:
|
>>> # Essentially, it is similar to following operation:
|
||||||
>>> scatter_list = input
|
>>> scatter_list = input
|
||||||
>>> gather_list = output
|
>>> gather_list = output
|
||||||
>>> for i in range(world_size):
|
>>> for i in range(world_size):
|
||||||
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
|
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
|
||||||
|
|
||||||
@ -4544,7 +4557,9 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
|
|||||||
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3
|
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3
|
||||||
|
|
||||||
>>> # Another example with tensors of torch.cfloat type.
|
>>> # Another example with tensors of torch.cfloat type.
|
||||||
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
|
>>> input = torch.tensor(
|
||||||
|
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
|
||||||
|
... ) + 4 * rank * (1 + 1j)
|
||||||
>>> input = list(input.chunk(4))
|
>>> input = list(input.chunk(4))
|
||||||
>>> input
|
>>> input
|
||||||
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0
|
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0
|
||||||
@ -4882,9 +4897,9 @@ def split_group(
|
|||||||
backend_config = BackendConfig(backend)
|
backend_config = BackendConfig(backend)
|
||||||
|
|
||||||
if pg_options is not None:
|
if pg_options is not None:
|
||||||
assert isinstance(
|
assert isinstance(pg_options, ProcessGroupNCCL.Options), (
|
||||||
pg_options, ProcessGroupNCCL.Options
|
"Expected pg_options argument to be of type ProcessGroupNCCL.Options"
|
||||||
), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
|
)
|
||||||
else:
|
else:
|
||||||
# default pg_options same as the parent process group
|
# default pg_options same as the parent process group
|
||||||
pg_options = parent_backend.options
|
pg_options = parent_backend.options
|
||||||
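The assert rewrites in the ``split_group`` hunk above follow the pattern ruff format applies throughout this diff: the condition stays intact on the ``assert`` line and the long message moves into trailing parentheses, instead of the condition itself being wrapped. A small runnable sketch with made-up names, assuming nothing beyond the standard library::

    def check_options(pg_options, expected_type):
        # Old (black) layout wrapped the condition:
        #     assert isinstance(
        #         pg_options, expected_type
        #     ), "Expected pg_options argument to be of type ..."
        # New (ruff format) layout keeps the condition whole and
        # parenthesizes the message instead.
        assert isinstance(pg_options, expected_type), (
            "Expected pg_options argument to be of type "
            f"{expected_type.__name__}, got {type(pg_options).__name__}"
        )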
@ -5086,9 +5101,9 @@ def _new_group_with_tag(
|
|||||||
if device_id is None:
|
if device_id is None:
|
||||||
device_id = default_pg.bound_device_id
|
device_id = default_pg.bound_device_id
|
||||||
elif default_pg.bound_device_id is not None:
|
elif default_pg.bound_device_id is not None:
|
||||||
assert (
|
assert device_id == default_pg.bound_device_id, (
|
||||||
device_id == default_pg.bound_device_id
|
"Mismatched bound device between new pg and the default pg."
|
||||||
), "Mismatched bound device between new pg and the default pg."
|
)
|
||||||
default_backend, default_store = _world.pg_map[default_pg]
|
default_backend, default_store = _world.pg_map[default_pg]
|
||||||
global_rank = default_pg.rank()
|
global_rank = default_pg.rank()
|
||||||
global_world_size = default_pg.size()
|
global_world_size = default_pg.size()
|
||||||
@ -5408,9 +5423,9 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro
|
|||||||
def _find_or_create_pg_by_ranks_and_tag(
|
def _find_or_create_pg_by_ranks_and_tag(
|
||||||
tag: str, ranks: list[int], stride: int
|
tag: str, ranks: list[int], stride: int
|
||||||
) -> ProcessGroup:
|
) -> ProcessGroup:
|
||||||
assert (
|
assert len(ranks) % stride == 0, (
|
||||||
len(ranks) % stride == 0
|
f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
|
||||||
), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
|
)
|
||||||
|
|
||||||
my_rank = get_rank()
|
my_rank = get_rank()
|
||||||
my_ranks = None
|
my_ranks = None
|
||||||
|
@ -40,8 +40,9 @@ def worker_main() -> Generator[None, None, None]:
|
|||||||
def main():
|
def main():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if __name__=="__main__":
|
|
||||||
main()
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with ExitStack() as stack:
|
with ExitStack() as stack:
|
||||||
|
@ -14,7 +14,10 @@ Example of usage:
|
|||||||
::
|
::
|
||||||
|
|
||||||
from torch.distributed.elastic import events
|
from torch.distributed.elastic import events
|
||||||
event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
|
|
||||||
|
event = events.Event(
|
||||||
|
name="test_event", source=events.EventSource.WORKER, metadata={...}
|
||||||
|
)
|
||||||
events.get_logging_handler(destination="console").info(event)
|
events.get_logging_handler(destination="console").info(event)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -52,11 +52,12 @@ The example below measures the latency for the ``calculate()`` function.
|
|||||||
metrics.configure(metrics.NullMetricsHandler())
|
metrics.configure(metrics.NullMetricsHandler())
|
||||||
metrics.configure(metrics.ConsoleMetricsHandler(), "my_module")
|
metrics.configure(metrics.ConsoleMetricsHandler(), "my_module")
|
||||||
|
|
||||||
|
|
||||||
def my_method():
|
def my_method():
|
||||||
start = time.time()
|
start = time.time()
|
||||||
calculate()
|
calculate()
|
||||||
end = time.time()
|
end = time.time()
|
||||||
metrics.put_metric("calculate_latency", int(end-start), "my_module")
|
metrics.put_metric("calculate_latency", int(end - start), "my_module")
|
||||||
|
|
||||||
You may also use the ``torch.distributed.elastic.metrics.prof`` decorator
|
You may also use the ``torch.distributed.elastic.metrics.prof`` decorator
|
||||||
to conveniently and succinctly profile functions
|
to conveniently and succinctly profile functions
|
||||||
@ -70,15 +71,16 @@ to conveniently and succinctly profile functions
|
|||||||
metrics.configure(metrics.ConsoleMetricsHandler(), "foobar")
|
metrics.configure(metrics.ConsoleMetricsHandler(), "foobar")
|
||||||
metrics.configure(metrics.ConsoleMetricsHandler(), "Bar")
|
metrics.configure(metrics.ConsoleMetricsHandler(), "Bar")
|
||||||
|
|
||||||
|
|
||||||
@metrics.prof
|
@metrics.prof
|
||||||
def foo():
|
def foo():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class Bar():
|
|
||||||
|
|
||||||
@metrics.prof
|
class Bar:
|
||||||
def baz():
|
@metrics.prof
|
||||||
pass
|
def baz():
|
||||||
|
pass
|
||||||
|
|
||||||
``@metrics.prof`` will publish the following metrics
|
``@metrics.prof`` will publish the following metrics
|
||||||
::
|
::
|
||||||
@ -102,8 +104,8 @@ console.
|
|||||||
|
|
||||||
import torch.distributed.elastic.metrics as metrics
|
import torch.distributed.elastic.metrics as metrics
|
||||||
|
|
||||||
metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic")
|
metrics.configure(metrics.ConsoleMetricHandler(), group="torchelastic")
|
||||||
metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app")
|
metrics.configure(metrics.ConsoleMetricHandler(), group="my_app")
|
||||||
|
|
||||||
**Writing a Custom Metric Handler**:
|
**Writing a Custom Metric Handler**:
|
||||||
|
|
||||||
@ -117,13 +119,15 @@ Below is a toy example that prints the metrics to ``stdout``
|
|||||||
|
|
||||||
import torch.distributed.elastic.metrics as metrics
|
import torch.distributed.elastic.metrics as metrics
|
||||||
|
|
||||||
|
|
||||||
class StdoutMetricHandler(metrics.MetricHandler):
|
class StdoutMetricHandler(metrics.MetricHandler):
|
||||||
def emit(self, metric_data):
|
def emit(self, metric_data):
|
||||||
ts = metric_data.timestamp
|
ts = metric_data.timestamp
|
||||||
group = metric_data.group_name
|
group = metric_data.group_name
|
||||||
name = metric_data.name
|
name = metric_data.name
|
||||||
value = metric_data.value
|
value = metric_data.value
|
||||||
print(f"[{ts}][{group}]: {name}={value}")
|
print(f"[{ts}][{group}]: {name}={value}")
|
||||||
|
|
||||||
|
|
||||||
metrics.configure(StdoutMetricHandler(), group="my_app")
|
metrics.configure(StdoutMetricHandler(), group="my_app")
|
||||||
|
|
||||||
|
@ -123,6 +123,7 @@ def prof(fn=None, group: str = "torchelastic"):
|
|||||||
def x():
|
def x():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@metrics.prof(group="agent")
|
@metrics.prof(group="agent")
|
||||||
def y():
|
def y():
|
||||||
pass
|
pass
|
||||||
|
@ -20,22 +20,23 @@ Usage 1: Launching two trainers as a function
|
|||||||
|
|
||||||
from torch.distributed.elastic.multiprocessing import Std, start_processes
|
from torch.distributed.elastic.multiprocessing import Std, start_processes
|
||||||
|
|
||||||
|
|
||||||
def trainer(a, b, c):
|
def trainer(a, b, c):
|
||||||
pass # train
|
pass # train
|
||||||
|
|
||||||
|
|
||||||
# runs two trainers
|
# runs two trainers
|
||||||
# LOCAL_RANK=0 trainer(1,2,3)
|
# LOCAL_RANK=0 trainer(1,2,3)
|
||||||
# LOCAL_RANK=1 trainer(4,5,6)
|
# LOCAL_RANK=1 trainer(4,5,6)
|
||||||
ctx = start_processes(
|
ctx = start_processes(
|
||||||
name="trainer",
|
name="trainer",
|
||||||
entrypoint=trainer,
|
entrypoint=trainer,
|
||||||
args={0: (1,2,3), 1: (4,5,6)},
|
args={0: (1, 2, 3), 1: (4, 5, 6)},
|
||||||
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
|
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
|
||||||
log_dir="/tmp/foobar",
|
log_dir="/tmp/foobar",
|
||||||
redirects=Std.ALL, # write all worker stdout/stderr to a log file
|
redirects=Std.ALL, # write all worker stdout/stderr to a log file
|
||||||
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
|
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
|
||||||
)
|
)
|
||||||
|
|
||||||
# waits for all copies of trainer to finish
|
# waits for all copies of trainer to finish
|
||||||
ctx.wait()
|
ctx.wait()
|
||||||
|
@ -165,9 +165,11 @@ def to_map(
|
|||||||
Example:
|
Example:
|
||||||
::
|
::
|
||||||
|
|
||||||
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
|
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
|
||||||
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
|
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
|
||||||
to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
|
to_map(
|
||||||
|
{0: Std.OUT, 1: Std.OUT}, local_world_size=2
|
||||||
|
) # returns: {0: Std.OUT, 1: Std.OUT}
|
||||||
"""
|
"""
|
||||||
if isinstance(val_or_map, Std):
|
if isinstance(val_or_map, Std):
|
||||||
return dict.fromkeys(range(local_world_size), val_or_map)
|
return dict.fromkeys(range(local_world_size), val_or_map)
|
||||||
@ -304,7 +306,9 @@ class DefaultLogsSpecs(LogsSpecs):
|
|||||||
if not self._run_log_dir:
|
if not self._run_log_dir:
|
||||||
self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
|
self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
|
||||||
|
|
||||||
attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload]
|
attempt_log_dir = os.path.join(
|
||||||
|
self._run_log_dir, f"attempt_{restart_count}"
|
||||||
|
) # type: ignore[call-overload]
|
||||||
shutil.rmtree(attempt_log_dir, ignore_errors=True)
|
shutil.rmtree(attempt_log_dir, ignore_errors=True)
|
||||||
os.makedirs(attempt_log_dir)
|
os.makedirs(attempt_log_dir)
|
||||||
|
|
||||||
@ -868,9 +872,7 @@ class SubprocessContext(PContext):
|
|||||||
if result.is_failed():
|
if result.is_failed():
|
||||||
first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
|
first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
|
||||||
logger.error(
|
logger.error(
|
||||||
"failed (exitcode: %s)"
|
"failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s",
|
||||||
" local_rank: %s (pid: %s)"
|
|
||||||
" of binary: %s",
|
|
||||||
first_failure.exitcode,
|
first_failure.exitcode,
|
||||||
first_failure.local_rank,
|
first_failure.local_rank,
|
||||||
first_failure.pid,
|
first_failure.pid,
|
||||||
|
@ -318,14 +318,14 @@ def record(
|
|||||||
error_handler = get_error_handler()
|
error_handler = get_error_handler()
|
||||||
error_handler.initialize()
|
error_handler.initialize()
|
||||||
try:
|
try:
|
||||||
foobar()
|
foobar()
|
||||||
except ChildFailedError as e:
|
except ChildFailedError as e:
|
||||||
_, failure = e.get_first_failure()
|
_, failure = e.get_first_failure()
|
||||||
error_handler.dump_error_file(failure.error_file, failure.exitcode)
|
error_handler.dump_error_file(failure.error_file, failure.exitcode)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_handler.record_exception(e)
|
error_handler.record_exception(e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
.. important:: use this decorator once per process at the top level method,
|
.. important:: use this decorator once per process at the top level method,
|
||||||
typically this is the main method.
|
typically this is the main method.
|
||||||
@ -338,8 +338,9 @@ def record(
|
|||||||
def main():
|
def main():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if __name__=="__main__":
|
|
||||||
main()
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not error_handler:
|
if not error_handler:
|
||||||
|
@ -120,11 +120,7 @@ of the following implementations that come with PyTorch:
|
|||||||
backend = C10dRendezvousBackend(store, "my_run_id")
|
backend = C10dRendezvousBackend(store, "my_run_id")
|
||||||
|
|
||||||
rdzv_handler = DynamicRendezvousHandler.from_backend(
|
rdzv_handler = DynamicRendezvousHandler.from_backend(
|
||||||
run_id="my_run_id",
|
run_id="my_run_id", store=store, backend=backend, min_nodes=2, max_nodes=4
|
||||||
store=store,
|
|
||||||
backend=backend,
|
|
||||||
min_nodes=2,
|
|
||||||
max_nodes=4
|
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -89,8 +89,14 @@ class RendezvousStoreInfo:
|
|||||||
addr = local_addr or socket.getfqdn()
|
addr = local_addr or socket.getfqdn()
|
||||||
# When TCPStore is not shared, we fallback to get_free_port.
|
# When TCPStore is not shared, we fallback to get_free_port.
|
||||||
port = server_port or get_free_port()
|
port = server_port or get_free_port()
|
||||||
store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type]
|
store.set(
|
||||||
store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type]
|
RendezvousStoreInfo.MASTER_ADDR_KEY,
|
||||||
|
addr.encode(encoding="UTF-8"), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
store.set(
|
||||||
|
RendezvousStoreInfo.MASTER_PORT_KEY,
|
||||||
|
str(port).encode(encoding="UTF-8"), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
|
addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
|
||||||
port = int(
|
port = int(
|
||||||
|
@ -413,9 +413,9 @@ class EtcdRendezvous:
|
|||||||
active_version = self.wait_for_peers(expected_version)
|
active_version = self.wait_for_peers(expected_version)
|
||||||
state = json.loads(active_version.value)
|
state = json.loads(active_version.value)
|
||||||
|
|
||||||
assert (
|
assert state["version"] == expected_version, (
|
||||||
state["version"] == expected_version
|
"Logic error: failed to observe version mismatch"
|
||||||
), "Logic error: failed to observe version mismatch"
|
)
|
||||||
|
|
||||||
return self.confirm_phase(expected_version, this_rank)
|
return self.confirm_phase(expected_version, this_rank)
|
||||||
|
|
||||||
@ -533,9 +533,9 @@ class EtcdRendezvous:
|
|||||||
"Rendezvous version changed. Must try join the new one."
|
"Rendezvous version changed. Must try join the new one."
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert len(state["participants"]) < self._num_max_workers, (
|
||||||
len(state["participants"]) < self._num_max_workers
|
"Logic error: joinable rendezvous should always have space left"
|
||||||
), "Logic error: joinable rendezvous should always have space left"
|
)
|
||||||
|
|
||||||
this_rank = len(state["participants"])
|
this_rank = len(state["participants"])
|
||||||
state["participants"].append(this_rank)
|
state["participants"].append(this_rank)
|
||||||
|
@ -86,11 +86,15 @@ def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
|
|||||||
from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
|
from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
|
||||||
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
|
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
|
||||||
|
|
||||||
|
|
||||||
def create_my_rdzv(params: RendezvousParameters):
|
def create_my_rdzv(params: RendezvousParameters):
|
||||||
return MyCustomRdzv(params)
|
return MyCustomRdzv(params)
|
||||||
|
|
||||||
|
|
||||||
rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
|
rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
|
||||||
|
|
||||||
my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters)
|
my_rdzv_handler = get_rendezvous_handler(
|
||||||
|
"my_rdzv_backend_name", RendezvousParameters
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
return handler_registry.create_handler(params)
|
return handler_registry.create_handler(params)
|
||||||
|
@ -57,10 +57,10 @@ def get_all(store, rank: int, prefix: str, world_size: int):
|
|||||||
|
|
||||||
::
|
::
|
||||||
|
|
||||||
values = get_all(store, 'torchelastic/data', 3)
|
values = get_all(store, "torchelastic/data", 3)
|
||||||
value1 = values[0] # retrieves the data for key torchelastic/data0
|
value1 = values[0] # retrieves the data for key torchelastic/data0
|
||||||
value2 = values[1] # retrieves the data for key torchelastic/data1
|
value2 = values[1] # retrieves the data for key torchelastic/data1
|
||||||
value3 = values[2] # retrieves the data for key torchelastic/data2
|
value3 = values[2] # retrieves the data for key torchelastic/data2
|
||||||
|
|
||||||
"""
|
"""
|
||||||
data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])
|
data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
"""
|
"""
|
||||||
This file includes private common utilities for FSDP.
|
This file includes private common utilities for FSDP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
@ -200,9 +201,9 @@ def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamH
|
|||||||
# handles, meaning no entry in `_fully_sharded_module_to_handles`
|
# handles, meaning no entry in `_fully_sharded_module_to_handles`
|
||||||
if state._handle is None:
|
if state._handle is None:
|
||||||
return None
|
return None
|
||||||
assert (
|
assert module in state._fully_sharded_module_to_handle, (
|
||||||
module in state._fully_sharded_module_to_handle
|
f"Expects a fully sharded module but got {module} on rank {state.rank}"
|
||||||
), f"Expects a fully sharded module but got {module} on rank {state.rank}"
|
)
|
||||||
return state._fully_sharded_module_to_handle[module]
|
return state._fully_sharded_module_to_handle[module]
|
||||||
else:
|
else:
|
||||||
# NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
|
# NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
|
||||||
@ -255,9 +256,9 @@ def _named_parameters_with_duplicates(
|
|||||||
This API is required as some modules overwrite `named_parameters()` but do not support
|
This API is required as some modules overwrite `named_parameters()` but do not support
|
||||||
`remove_duplicate`.
|
`remove_duplicate`.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert "remove_duplicate" not in kwargs, (
|
||||||
"remove_duplicate" not in kwargs
|
"_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
|
||||||
), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
|
)
|
||||||
kwargs["remove_duplicate"] = False
|
kwargs["remove_duplicate"] = False
|
||||||
try:
|
try:
|
||||||
ret = list(module.named_parameters(**kwargs))
|
ret = list(module.named_parameters(**kwargs))
|
||||||
|
@ -190,9 +190,9 @@ class _ExecOrderData:
|
|||||||
return
|
return
|
||||||
if self.is_first_iter:
|
if self.is_first_iter:
|
||||||
msg_prefix = "Forward order differs across ranks:"
|
msg_prefix = "Forward order differs across ranks:"
|
||||||
optional_local_indices: tuple[
|
optional_local_indices: tuple[Optional[int], ...] = (
|
||||||
Optional[int], ...
|
self._get_handle_indices(handle)
|
||||||
] = self._get_handle_indices(handle)
|
)
|
||||||
device = handle.device # guaranteed to be non-CPU
|
device = handle.device # guaranteed to be non-CPU
|
||||||
num_valid_indices = sum(
|
num_valid_indices = sum(
|
||||||
(index is not None) for index in optional_local_indices
|
(index is not None) for index in optional_local_indices
|
||||||
@ -250,8 +250,7 @@ class _ExecOrderData:
|
|||||||
(
|
(
|
||||||
rank,
|
rank,
|
||||||
world_indices[
|
world_indices[
|
||||||
rank
|
rank * num_valid_indices : (rank + 1)
|
||||||
* num_valid_indices : (rank + 1)
|
|
||||||
* num_valid_indices
|
* num_valid_indices
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -586,7 +586,10 @@ class FlatParamHandle:
|
|||||||
)
|
)
|
||||||
self._fsdp_extension = fsdp_extension
|
self._fsdp_extension = fsdp_extension
|
||||||
self._init_flat_param_and_metadata(
|
self._init_flat_param_and_metadata(
|
||||||
params, fully_sharded_module, self._aligned_numel, use_orig_params # type: ignore[arg-type]
|
params,
|
||||||
|
fully_sharded_module,
|
||||||
|
self._aligned_numel,
|
||||||
|
use_orig_params, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
self._use_unsharded_views(as_params=False)
|
self._use_unsharded_views(as_params=False)
|
||||||
|
|
||||||
@ -978,9 +981,9 @@ class FlatParamHandle:
|
|||||||
shard_param_infos = self._get_shard_metadata(
|
shard_param_infos = self._get_shard_metadata(
|
||||||
unsharded_start_idx, unsharded_end_idx
|
unsharded_start_idx, unsharded_end_idx
|
||||||
)
|
)
|
||||||
assert (
|
assert len(shard_param_infos) == flat_param._num_params, (
|
||||||
len(shard_param_infos) == flat_param._num_params
|
f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
|
||||||
), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
|
)
|
||||||
flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
|
flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
|
||||||
flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
|
flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
|
||||||
|
|
||||||
@ -996,9 +999,9 @@ class FlatParamHandle:
|
|||||||
unsharded flat parameter specifying the shard.
|
unsharded flat parameter specifying the shard.
|
||||||
"""
|
"""
|
||||||
flat_param_offsets = self._get_flat_param_offsets()
|
flat_param_offsets = self._get_flat_param_offsets()
|
||||||
assert len(flat_param_offsets) == len(
|
assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), (
|
||||||
self.flat_param._numels_with_padding
|
f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
|
||||||
), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
|
)
|
||||||
shard_param_infos: list[_ShardParamInfo] = []
|
shard_param_infos: list[_ShardParamInfo] = []
|
||||||
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
|
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
|
||||||
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
|
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
|
||||||
@ -1075,9 +1078,9 @@ class FlatParamHandle:
|
|||||||
else:
|
else:
|
||||||
chunk = chunks[rank]
|
chunk = chunks[rank]
|
||||||
numel_to_pad = chunks[0].numel() - chunk.numel()
|
numel_to_pad = chunks[0].numel() - chunk.numel()
|
||||||
assert (
|
assert numel_to_pad >= 0, (
|
||||||
numel_to_pad >= 0
|
"Chunk's size should be at most the first chunk's size"
|
||||||
), "Chunk's size should be at most the first chunk's size"
|
)
|
||||||
return chunk, numel_to_pad
|
return chunk, numel_to_pad
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -1302,7 +1305,8 @@ class FlatParamHandle:
|
|||||||
self._check_low_precision_shard()
|
self._check_low_precision_shard()
|
||||||
flat_param = self.flat_param
|
flat_param = self.flat_param
|
||||||
_alloc_storage(
|
_alloc_storage(
|
||||||
flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined]
|
flat_param._mp_shard,
|
||||||
|
flat_param._local_shard.size(), # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
# `copy_()` implicitly casts to the low precision
|
# `copy_()` implicitly casts to the low precision
|
||||||
flat_param._mp_shard.copy_( # type: ignore[attr-defined]
|
flat_param._mp_shard.copy_( # type: ignore[attr-defined]
|
||||||
@ -1498,7 +1502,8 @@ class FlatParamHandle:
|
|||||||
# default stream suffices since the default stream waits for the
|
# default stream suffices since the default stream waits for the
|
||||||
# unshard stream.
|
# unshard stream.
|
||||||
_no_dispatch_record_stream(
|
_no_dispatch_record_stream(
|
||||||
self.flat_param._mp_shard, self._device_handle.current_stream() # type: ignore[attr-defined]
|
self.flat_param._mp_shard,
|
||||||
|
self._device_handle.current_stream(), # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
|
_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
|
||||||
|
|
||||||
@ -1593,8 +1598,7 @@ class FlatParamHandle:
|
|||||||
f"but got {flat_param.grad.device}",
|
f"but got {flat_param.grad.device}",
|
||||||
)
|
)
|
||||||
prev_iter_synced_gradients = (
|
prev_iter_synced_gradients = (
|
||||||
flat_param.grad.size()
|
flat_param.grad.size() == flat_param._local_shard.size() # type: ignore[attr-defined]
|
||||||
== flat_param._local_shard.size() # type: ignore[attr-defined]
|
|
||||||
)
|
)
|
||||||
if prev_iter_synced_gradients:
|
if prev_iter_synced_gradients:
|
||||||
# TODO (awgu): Gradient accumulation outside `no_sync()`
|
# TODO (awgu): Gradient accumulation outside `no_sync()`
|
||||||
@ -1668,8 +1672,7 @@ class FlatParamHandle:
|
|||||||
cast_grad_to_param_dtype_if_needed(flat_param)
|
cast_grad_to_param_dtype_if_needed(flat_param)
|
||||||
else:
|
else:
|
||||||
_p_assert(
|
_p_assert(
|
||||||
not self.uses_sharded_strategy
|
not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined]
|
||||||
or not flat_param._post_backward_called, # type: ignore[attr-defined]
|
|
||||||
"All sharded parameters that received a gradient in the "
|
"All sharded parameters that received a gradient in the "
|
||||||
"post-backward should use `_saved_grad_shard`",
|
"post-backward should use `_saved_grad_shard`",
|
||||||
)
|
)
|
||||||
@ -2504,7 +2507,8 @@ class FlatParamHandle:
|
|||||||
"""Return the FQNs of the parameters present in this rank's shard."""
|
"""Return the FQNs of the parameters present in this rank's shard."""
|
||||||
fqns_in_shard: list[str] = []
|
fqns_in_shard: list[str] = []
|
||||||
for fqn, shard_param_info in zip(
|
for fqn, shard_param_info in zip(
|
||||||
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
|
self.flat_param._fqns,
|
||||||
|
self.flat_param._shard_param_infos, # type: ignore[attr-defined]
|
||||||
):
|
):
|
||||||
if shard_param_info.in_shard:
|
if shard_param_info.in_shard:
|
||||||
fqns_in_shard.append(fqn)
|
fqns_in_shard.append(fqn)
|
||||||
@ -2694,7 +2698,7 @@ def _safe_setattr_tensor_or_param(
|
|||||||
|
|
||||||
|
|
||||||
def _convert_to_params(
|
def _convert_to_params(
|
||||||
tensors: list[Union[torch.Tensor, nn.Parameter]]
|
tensors: list[Union[torch.Tensor, nn.Parameter]],
|
||||||
) -> list[nn.Parameter]:
|
) -> list[nn.Parameter]:
|
||||||
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
|
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
|
||||||
|
|
||||||
|
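The ``_convert_to_params`` hunk above is a pure punctuation change: when a lone parameter's annotation already forces the signature onto several lines, ruff format appends a magic trailing comma after it. A rough sketch mirroring that function, with the imports it would need added here for completeness::

    from typing import Union

    import torch
    from torch import nn


    def convert_to_params(
        tensors: list[Union[torch.Tensor, nn.Parameter]],  # trailing comma added
    ) -> list[nn.Parameter]:
        return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]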
@ -374,9 +374,9 @@ def foreach_reduce(
|
|||||||
for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)):
|
for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)):
|
||||||
if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
|
if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
|
||||||
continue
|
continue
|
||||||
assert (
|
assert unsharded_grad.size(shard_dim) % world_size == 0, (
|
||||||
unsharded_grad.size(shard_dim) % world_size == 0
|
f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
|
||||||
), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
|
)
|
||||||
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
|
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
|
||||||
unsharded_grads[i] = torch.cat(chunks, dim=0)
|
unsharded_grads[i] = torch.cat(chunks, dim=0)
|
||||||
padded_unsharded_sizes = tuple(
|
padded_unsharded_sizes = tuple(
|
||||||
|
@ -26,9 +26,9 @@ if torch._running_with_deploy():
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
def detect_compiled_autograd():
|
def detect_compiled_autograd():
|
||||||
assert (
|
assert not torch.compiler.is_compiling(), (
|
||||||
not torch.compiler.is_compiling()
|
"`detect_compiled_autograd()` is designed to be called in eager mode"
|
||||||
), "`detect_compiled_autograd()` is designed to be called in eager mode"
|
)
|
||||||
global _compiled_autograd_enabled
|
global _compiled_autograd_enabled
|
||||||
import torch._dynamo.compiled_autograd as ca
|
import torch._dynamo.compiled_autograd as ca
|
||||||
|
|
||||||
|
@ -304,9 +304,9 @@ class FSDPParam:
|
|||||||
f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
|
f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
|
||||||
)
|
)
|
||||||
split_factor = self._tp_spec.num_shards_map[shard_dim]
|
split_factor = self._tp_spec.num_shards_map[shard_dim]
|
||||||
assert (
|
assert 2 <= self._spmd_mesh.ndim <= 3, (
|
||||||
2 <= self._spmd_mesh.ndim <= 3
|
f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
|
||||||
), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
|
)
|
||||||
self._spmd_placements: tuple[Placement, ...]
|
self._spmd_placements: tuple[Placement, ...]
|
||||||
dp_shard_tp_placement = (
|
dp_shard_tp_placement = (
|
||||||
(
|
(
|
||||||
@ -520,8 +520,9 @@ class FSDPParam:
|
|||||||
unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
|
unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
|
||||||
if hasattr(self, "_unsharded_param"):
|
if hasattr(self, "_unsharded_param"):
|
||||||
assert compiled_autograd_enabled()
|
assert compiled_autograd_enabled()
|
||||||
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
|
with (
|
||||||
self._unsharded_param
|
torch.no_grad(),
|
||||||
|
torch.autograd._unsafe_preserve_version_counter(self._unsharded_param),
|
||||||
):
|
):
|
||||||
# NOTE: Under compile, if an unsharded param goes through
|
# NOTE: Under compile, if an unsharded param goes through
|
||||||
# resize_(full) -> copy_ -> resize_(0) pattern, we will remove those
|
# resize_(full) -> copy_ -> resize_(0) pattern, we will remove those
|
||||||
@ -785,9 +786,9 @@ class FSDPParam:
|
|||||||
assert isinstance(grad, DTensor), f"{type(grad)}"
|
assert isinstance(grad, DTensor), f"{type(grad)}"
|
||||||
placements = self._tp_spec.placements
|
placements = self._tp_spec.placements
|
||||||
if placements != grad.placements:
|
if placements != grad.placements:
|
||||||
assert len(self._tp_spec.placements) == len(
|
assert len(self._tp_spec.placements) == len(grad.placements), (
|
||||||
grad.placements
|
f"{self._tp_spec=} {grad.placements=}"
|
||||||
), f"{self._tp_spec=} {grad.placements=}"
|
)
|
||||||
grad = grad.redistribute(placements=placements)
|
grad = grad.redistribute(placements=placements)
|
||||||
grad = grad._local_tensor
|
grad = grad._local_tensor
|
||||||
return grad
|
return grad
|
||||||
@ -846,9 +847,9 @@ class FSDPParam:
|
|||||||
shard_dim = self.fsdp_placement.dim
|
shard_dim = self.fsdp_placement.dim
|
||||||
length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
|
length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
|
||||||
if local_tensor.size() != padded_sharded_size:
|
if local_tensor.size() != padded_sharded_size:
|
||||||
assert (
|
assert shard_dim == 0, (
|
||||||
shard_dim == 0
|
f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
|
||||||
), f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
|
)
|
||||||
padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
|
padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
|
||||||
padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(
|
padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(
|
||||||
local_tensor
|
local_tensor
|
||||||
|
@ -424,9 +424,9 @@ class FSDPParamGroup:
|
|||||||
if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
|
if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
|
||||||
# this means the native HSDP is not enabled,
|
# this means the native HSDP is not enabled,
|
||||||
# but user may want to have a custom HSDP setup
|
# but user may want to have a custom HSDP setup
|
||||||
assert (
|
assert self._all_reduce_hook is not None, (
|
||||||
self._all_reduce_hook is not None
|
"all reduce hook stream is specified but hook itself is missing."
|
||||||
), "all reduce hook stream is specified but hook itself is missing."
|
)
|
||||||
all_reduce_stream = self._all_reduce_hook_stream
|
all_reduce_stream = self._all_reduce_hook_stream
|
||||||
else:
|
else:
|
||||||
all_reduce_stream = self.comm_ctx.all_reduce_stream
|
all_reduce_stream = self.comm_ctx.all_reduce_stream
|
||||||
@ -513,9 +513,10 @@ class FSDPParamGroup:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown pass type: {pass_type}")
|
raise ValueError(f"Unknown pass type: {pass_type}")
|
||||||
target_fqn = target_fsdp_param_group._module_fqn
|
target_fqn = target_fsdp_param_group._module_fqn
|
||||||
with record_function(
|
with (
|
||||||
f"FSDP::{pass_type}_prefetch for {target_fqn}"
|
record_function(f"FSDP::{pass_type}_prefetch for {target_fqn}"),
|
||||||
), target_fsdp_param_group.use_training_state(training_state):
|
target_fsdp_param_group.use_training_state(training_state),
|
||||||
|
):
|
||||||
async_op = target_fsdp_param_group.unshard_async_op
|
async_op = target_fsdp_param_group.unshard_async_op
|
||||||
target_fsdp_param_group.unshard(async_op)
|
target_fsdp_param_group.unshard(async_op)
|
||||||
|
|
||||||
@ -592,9 +593,9 @@ class FSDPParamGroup:
|
|||||||
def _register_state_dict_hooks(self) -> None:
|
def _register_state_dict_hooks(self) -> None:
|
||||||
num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
|
num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
|
||||||
num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
|
num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
|
||||||
assert (
|
assert num_pre_save_hooks == num_pre_load_hooks, (
|
||||||
num_pre_save_hooks == num_pre_load_hooks
|
f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
|
||||||
), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
|
)
|
||||||
if num_pre_save_hooks > 0:
|
if num_pre_save_hooks > 0:
|
||||||
return # already registered
|
return # already registered
|
||||||
modules_with_fsdp_params: set[nn.Module] = {
|
modules_with_fsdp_params: set[nn.Module] = {
|
||||||
@ -605,12 +606,12 @@ class FSDPParamGroup:
|
|||||||
self._to_sharded()
|
self._to_sharded()
|
||||||
|
|
||||||
for module in modules_with_fsdp_params:
|
for module in modules_with_fsdp_params:
|
||||||
self._module_to_pre_save_state_dict_hook_handle[
|
self._module_to_pre_save_state_dict_hook_handle[module] = (
|
||||||
module
|
module.register_state_dict_pre_hook(to_sharded_hook)
|
||||||
] = module.register_state_dict_pre_hook(to_sharded_hook)
|
)
|
||||||
self._module_to_pre_load_state_dict_hook_handle[
|
self._module_to_pre_load_state_dict_hook_handle[module] = (
|
||||||
module
|
module._register_load_state_dict_pre_hook(to_sharded_hook)
|
||||||
] = module._register_load_state_dict_pre_hook(to_sharded_hook)
|
)
|
||||||
|
|
||||||
# Properties #
|
# Properties #
|
||||||
@property
|
@property
|
||||||
|
@ -60,8 +60,7 @@ def fully_shard(
|
|||||||
mp_policy: MixedPrecisionPolicy = ...,
|
mp_policy: MixedPrecisionPolicy = ...,
|
||||||
offload_policy: OffloadPolicy = ...,
|
offload_policy: OffloadPolicy = ...,
|
||||||
ignored_params: Optional[set[nn.Parameter]] = ...,
|
ignored_params: Optional[set[nn.Parameter]] = ...,
|
||||||
) -> FSDPModule:
|
) -> FSDPModule: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@ -74,8 +73,7 @@ def fully_shard(
|
|||||||
mp_policy: MixedPrecisionPolicy = ...,
|
mp_policy: MixedPrecisionPolicy = ...,
|
||||||
offload_policy: OffloadPolicy = ...,
|
offload_policy: OffloadPolicy = ...,
|
||||||
ignored_params: Optional[set[nn.Parameter]] = ...,
|
ignored_params: Optional[set[nn.Parameter]] = ...,
|
||||||
) -> list[FSDPModule]:
|
) -> list[FSDPModule]: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
# The decorator adds a state object to `module` that can be accessed via
|
# The decorator adds a state object to `module` that can be accessed via
|
||||||
|
@ -243,9 +243,9 @@ def _init_inter_node_process_group(
|
|||||||
if local_rank == my_local_rank:
|
if local_rank == my_local_rank:
|
||||||
inter_node_pg = grp
|
inter_node_pg = grp
|
||||||
|
|
||||||
assert (
|
assert inter_node_pg is not None, (
|
||||||
inter_node_pg is not None
|
f"{my_local_rank} expected to assign inter-node pg, but did not"
|
||||||
), f"{my_local_rank} expected to assign inter-node pg, but did not"
|
)
|
||||||
return inter_node_pg
|
return inter_node_pg
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,9 +145,9 @@ def _unflatten_optim_state(
|
|||||||
dict will need to map these entries using the proper unflattened
|
dict will need to map these entries using the proper unflattened
|
||||||
parameter IDs.
|
parameter IDs.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert not shard_state or to_save, (
|
||||||
not shard_state or to_save
|
"If ``shard_state`` is True, ``to_save`` has to be True."
|
||||||
), "If ``shard_state`` is True, ``to_save`` has to be True."
|
)
|
||||||
consolidated_state = _communicate_optim_state(
|
consolidated_state = _communicate_optim_state(
|
||||||
fsdp_param_info,
|
fsdp_param_info,
|
||||||
flat_param_state,
|
flat_param_state,
|
||||||
@ -218,9 +218,9 @@ def _communicate_optim_state(
|
|||||||
):
|
):
|
||||||
tensor_state[state_name] = value
|
tensor_state[state_name] = value
|
||||||
continue
|
continue
|
||||||
assert (
|
assert fsdp_state.compute_device is not None, (
|
||||||
fsdp_state.compute_device is not None
|
"compute_device has not been initialized"
|
||||||
), "compute_device has not been initialized"
|
)
|
||||||
if value.device.type != fsdp_state.compute_device.type:
|
if value.device.type != fsdp_state.compute_device.type:
|
||||||
value = value.to(fsdp_state.compute_device)
|
value = value.to(fsdp_state.compute_device)
|
||||||
# Assume that positive-dimension tensor optimizer state
|
# Assume that positive-dimension tensor optimizer state
|
||||||
@ -394,7 +394,10 @@ def _shard_orig_param_state(
|
|||||||
and value.dim() > 0
|
and value.dim() > 0
|
||||||
and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
|
and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
|
||||||
):
|
):
|
||||||
value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator]
|
value = value.flatten()[
|
||||||
|
intra_param_start_idx : intra_param_end_idx # type: ignore[operator]
|
||||||
|
+ 1
|
||||||
|
].clone()
|
||||||
new_optim_state[state_name] = value
|
new_optim_state[state_name] = value
|
||||||
return new_optim_state
|
return new_optim_state
|
||||||
|
|
||||||
@ -489,9 +492,9 @@ def _flatten_optim_state_dict(
|
|||||||
if flat_state:
|
if flat_state:
|
||||||
flat_osd_state[key] = flat_state
|
flat_osd_state[key] = flat_state
|
||||||
elif use_orig_params:
|
elif use_orig_params:
|
||||||
assert (
|
assert len(fqns) == 1, (
|
||||||
len(fqns) == 1
|
f"use_orig_params is True but there are multiple FQNs, {fqns}."
|
||||||
), f"use_orig_params is True but there are multiple FQNs, {fqns}."
|
)
|
||||||
if optim is not None: # NamedOptimizer or KeyedOptimizer case.
|
if optim is not None: # NamedOptimizer or KeyedOptimizer case.
|
||||||
state = optim.state.get(param, None) # type: ignore[call-overload]
|
state = optim.state.get(param, None) # type: ignore[call-overload]
|
||||||
if state is not None:
|
if state is not None:
|
||||||
@ -570,14 +573,13 @@ def _flatten_optim_state(
|
|||||||
flat_param = handle.flat_param
|
flat_param = handle.flat_param
|
||||||
num_unflat_params = len(unflat_param_names)
|
num_unflat_params = len(unflat_param_names)
|
||||||
assert num_unflat_params > 0, (
|
assert num_unflat_params > 0, (
|
||||||
"Expects at least one unflattened parameter corresponding to the "
|
"Expects at least one unflattened parameter corresponding to the flat parameter"
|
||||||
"flat parameter"
|
|
||||||
)
|
)
|
||||||
unflat_param_shapes = flat_param._shapes
|
unflat_param_shapes = flat_param._shapes
|
||||||
num_unflat_param_shapes = len(unflat_param_shapes)
|
num_unflat_param_shapes = len(unflat_param_shapes)
|
||||||
assert (
|
assert num_unflat_params == num_unflat_param_shapes, (
|
||||||
num_unflat_params == num_unflat_param_shapes
|
f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
|
||||||
), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
|
)
|
||||||
|
|
||||||
# Check if these unflattened parameters have any optimizer state
|
# Check if these unflattened parameters have any optimizer state
|
||||||
has_state = [
|
has_state = [
|
||||||
@ -759,8 +761,7 @@ def _flatten_tensor_optim_state(
|
|||||||
flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
|
flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
|
||||||
flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined]
|
flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined]
|
||||||
assert flat_tensor.shape == flat_param_shape, (
|
assert flat_tensor.shape == flat_param_shape, (
|
||||||
f"tensor optim state: {flat_tensor.shape} "
|
f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}"
|
||||||
f"flat parameter: {flat_param_shape}"
|
|
||||||
)
|
)
|
||||||
return flat_tensor
|
return flat_tensor
|
||||||
|
|
||||||
@ -1065,9 +1066,9 @@ def _get_param_key_to_param(
|
|||||||
"""
|
"""
|
||||||
clean_fqn_to_curr_fqn: dict[str, str] = {}
|
clean_fqn_to_curr_fqn: dict[str, str] = {}
|
||||||
if is_named_optimizer:
|
if is_named_optimizer:
|
||||||
assert (
|
assert param_to_fqns is not None and flat_param_to_fqn is not None, (
|
||||||
param_to_fqns is not None and flat_param_to_fqn is not None
|
"The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
|
||||||
), "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
|
)
|
||||||
assert model is not None
|
assert model is not None
|
||||||
for key, _ in _named_parameters_with_duplicates(model):
|
for key, _ in _named_parameters_with_duplicates(model):
|
||||||
clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
|
clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
|
||||||
@ -1150,9 +1151,9 @@ def _check_missing_keys_on_rank(
|
|||||||
continue
|
continue
|
||||||
param_key = optim_state_key_to_param_key[r0_optim_state_key]
|
param_key = optim_state_key_to_param_key[r0_optim_state_key]
|
||||||
if isinstance(param_key, int):
|
if isinstance(param_key, int):
|
||||||
assert param_key >= 0 and param_key < len(
|
assert param_key >= 0 and param_key < len(param_key_to_param), (
|
||||||
param_key_to_param
|
"Check the `param_key_to_param` construction"
|
||||||
), "Check the `param_key_to_param` construction"
|
)
|
||||||
# We cannot use FSDPState.compute_device as this API is a global view.
|
# We cannot use FSDPState.compute_device as this API is a global view.
|
||||||
device = _get_pg_default_device(group)
|
device = _get_pg_default_device(group)
|
||||||
num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)
|
num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)
|
||||||
|
@ -121,9 +121,9 @@ def _all_gather_dtensor(
|
|||||||
"""
|
"""
|
||||||
All gather a DTensor in its sharded dimension and return the local tensor.
|
All gather a DTensor in its sharded dimension and return the local tensor.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert root_mesh == tensor.device_mesh, (
|
||||||
root_mesh == tensor.device_mesh
|
"The device mesh of a tensor should be a root mesh."
|
||||||
), "The device mesh of a tensor should be a root mesh."
|
)
|
||||||
|
|
||||||
placements = list(copy.deepcopy(tensor.placements))
|
placements = list(copy.deepcopy(tensor.placements))
|
||||||
# FSDP placements: [Shard(0)] -> [Replicate()]
|
# FSDP placements: [Shard(0)] -> [Replicate()]
|
||||||
|
@ -466,9 +466,9 @@ def _local_pre_load_state_dict_hook(
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
load_tensor = state_dict[fqn]
|
load_tensor = state_dict[fqn]
|
||||||
assert isinstance(
|
assert isinstance(load_tensor, ShardedTensor), (
|
||||||
load_tensor, ShardedTensor
|
"Tensors in local_state_dict should be ShardedTensor."
|
||||||
), "Tensors in local_state_dict should be ShardedTensor."
|
)
|
||||||
|
|
||||||
# Convert the ShardedTensor to a Tensor.
|
# Convert the ShardedTensor to a Tensor.
|
||||||
flat_param = _module_handle(fsdp_state, module).flat_param
|
flat_param = _module_handle(fsdp_state, module).flat_param
|
||||||
|
@ -143,9 +143,9 @@ class _ExecOrderTracer:
|
|||||||
named_params = list(module.named_parameters())
|
named_params = list(module.named_parameters())
|
||||||
curr_module = exec_info.curr_module
|
curr_module = exec_info.curr_module
|
||||||
if named_params:
|
if named_params:
|
||||||
assert (
|
assert curr_module in exec_info.module_to_param_usage_infos, (
|
||||||
curr_module in exec_info.module_to_param_usage_infos
|
"The current module should have already been processed by a patched `call_module`"
|
||||||
), "The current module should have already been processed by a patched `call_module`"
|
)
|
||||||
exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
|
exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
|
||||||
_ParamUsageInfo(module, named_params)
|
_ParamUsageInfo(module, named_params)
|
||||||
)
|
)
|
||||||
|
@ -185,9 +185,9 @@ def _unshard_fsdp_state_params(
|
|||||||
yield
|
yield
|
||||||
return
|
return
|
||||||
|
|
||||||
assert (
|
assert handle._training_state == HandleTrainingState.IDLE, (
|
||||||
handle._training_state == HandleTrainingState.IDLE
|
f"Expects the handle training to be IDLE but got {handle._training_state}"
|
||||||
), f"Expects the handle training to be IDLE but got {handle._training_state}"
|
)
|
||||||
|
|
||||||
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS
|
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS
|
||||||
|
|
||||||
|
@ -306,16 +306,21 @@ class FullStateDictConfig(StateDictConfig):
|
|||||||
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||||
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
|
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
|
||||||
>>> state = fsdp.state_dict()
|
>>> state = fsdp.state_dict()
|
||||||
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
|
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
|
||||||
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
|
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
|
||||||
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
|
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
|
||||||
>>> if dist.get_rank() == 0:
|
>>> if dist.get_rank() == 0:
|
||||||
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
|
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
|
||||||
>>> state_dict = torch.load("my_checkpoint.pt")
|
>>> state_dict = torch.load("my_checkpoint.pt")
|
||||||
>>> model.load_state_dict(state_dict)
|
>>> model.load_state_dict(state_dict)
|
||||||
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
|
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
|
||||||
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
|
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
|
||||||
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
|
>>> fsdp = FSDP(
|
||||||
|
... model,
|
||||||
|
... device_id=torch.cuda.current_device(),
|
||||||
|
... auto_wrap_policy=...,
|
||||||
|
... sync_module_states=True,
|
||||||
|
... )
|
||||||
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
|
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -723,9 +723,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
|||||||
if prev_state_dict_type is None:
|
if prev_state_dict_type is None:
|
||||||
prev_state_dict_type = submodule._state_dict_type
|
prev_state_dict_type = submodule._state_dict_type
|
||||||
else:
|
else:
|
||||||
assert (
|
assert prev_state_dict_type == submodule._state_dict_type, (
|
||||||
prev_state_dict_type == submodule._state_dict_type
|
"All FSDP modules should have the same state_dict_type."
|
||||||
), "All FSDP modules should have the same state_dict_type."
|
)
|
||||||
if prev_state_dict_config is None:
|
if prev_state_dict_config is None:
|
||||||
prev_state_dict_config = submodule._state_dict_config
|
prev_state_dict_config = submodule._state_dict_config
|
||||||
else:
|
else:
|
||||||
@ -738,7 +738,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
|||||||
assert isinstance(
|
assert isinstance(
|
||||||
submodule._optim_state_dict_config,
|
submodule._optim_state_dict_config,
|
||||||
type(prev_optim_state_dict_config),
|
type(prev_optim_state_dict_config),
|
||||||
), "All FSDP modules must have the same type of optim_state_dict_config."
|
), (
|
||||||
|
"All FSDP modules must have the same type of optim_state_dict_config."
|
||||||
|
)
|
||||||
|
|
||||||
submodule._state_dict_type = state_dict_type
|
submodule._state_dict_type = state_dict_type
|
||||||
submodule._state_dict_config = state_dict_config
|
submodule._state_dict_config = state_dict_config
|
||||||
@ -2153,9 +2155,9 @@ def _get_param_to_fqn(
|
|||||||
"""
|
"""
|
||||||
param_to_param_names = _get_param_to_fqns(model)
|
param_to_param_names = _get_param_to_fqns(model)
|
||||||
for param_names in param_to_param_names.values():
|
for param_names in param_to_param_names.values():
|
||||||
assert (
|
assert len(param_names) > 0, (
|
||||||
len(param_names) > 0
|
"`_get_param_to_fqns()` should not construct empty lists"
|
||||||
), "`_get_param_to_fqns()` should not construct empty lists"
|
)
|
||||||
if len(param_names) > 1:
|
if len(param_names) > 1:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Each parameter should only map to one parameter name but got "
|
"Each parameter should only map to one parameter name but got "
|
||||||
|
@ -112,20 +112,16 @@ class ShardedGradScaler(GradScaler):
|
|||||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def scale(self, outputs: torch.Tensor) -> torch.Tensor:
|
def scale(self, outputs: torch.Tensor) -> torch.Tensor: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]:
|
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
|
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
|
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ...
|
||||||
...
|
|
||||||
|
|
||||||
def scale(
|
def scale(
|
||||||
self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
|
self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
|
||||||
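The ``ShardedGradScaler`` hunk above collapses each ``@overload`` stub so that the ``...`` body sits on the ``def`` line itself; only the runtime implementation keeps a wrapped signature. A toy, hypothetical class showing the same layout::

    from collections.abc import Iterable
    from typing import Union, overload

    import torch


    class TinyScaler:
        @overload
        def scale(self, outputs: torch.Tensor) -> torch.Tensor: ...
        @overload
        def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ...

        def scale(
            self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
        ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
            # Toy behavior so the example runs: scale by 1.0.
            if isinstance(outputs, torch.Tensor):
                return outputs * 1.0
            return [o * 1.0 for o in outputs]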
@ -323,8 +319,10 @@ class ShardedGradScaler(GradScaler):
|
|||||||
if isinstance(new_scale, float):
|
if isinstance(new_scale, float):
|
||||||
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
||||||
else:
|
else:
|
||||||
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
|
reason = (
|
||||||
|
"new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
|
||||||
torch.FloatTensor with requires_grad=False."
|
torch.FloatTensor with requires_grad=False."
|
||||||
|
)
|
||||||
assert new_scale.device.type == self._device, reason
|
assert new_scale.device.type == self._device, reason
|
||||||
assert new_scale.numel() == 1, reason
|
assert new_scale.numel() == 1, reason
|
||||||
assert new_scale.requires_grad is False, reason
|
assert new_scale.requires_grad is False, reason
|
||||||
|
@ -61,9 +61,9 @@ def _post_order_apply(
|
|||||||
"Non-root modules should have their module name set but got "
|
"Non-root modules should have their module name set but got "
|
||||||
f"an empty module name for {module}"
|
f"an empty module name for {module}"
|
||||||
)
|
)
|
||||||
assert isinstance(
|
assert isinstance(optional_module, nn.Module), (
|
||||||
optional_module, nn.Module
|
f"fn should return None or an nn.Module but got {optional_module}"
|
||||||
), f"fn should return None or an nn.Module but got {optional_module}"
|
)
|
||||||
setattr(parent_module, module_name, optional_module)
|
setattr(parent_module, module_name, optional_module)
|
||||||
|
|
||||||
_post_order_apply_inner(root_module, "", None)
|
_post_order_apply_inner(root_module, "", None)
|
||||||
@ -575,9 +575,9 @@ class _ConfigAutoWrap:
|
|||||||
)
|
)
|
||||||
_ConfigAutoWrap.in_autowrap_context = True
|
_ConfigAutoWrap.in_autowrap_context = True
|
||||||
# Get and save the wrapper cls for the context.
|
# Get and save the wrapper cls for the context.
|
||||||
assert (
|
assert "wrapper_cls" in kwargs.keys(), (
|
||||||
"wrapper_cls" in kwargs.keys()
|
"Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
|
||||||
), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
|
)
|
||||||
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
|
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
|
||||||
del kwargs["wrapper_cls"]
|
del kwargs["wrapper_cls"]
|
||||||
# Save the rest.
|
# Save the rest.
|
||||||
|
@ -183,8 +183,7 @@ def parse_args(args):
|
|||||||
def launch(args):
|
def launch(args):
|
||||||
if args.no_python and not args.use_env:
|
if args.no_python and not args.use_env:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"When using the '--no-python' flag,"
|
"When using the '--no-python' flag, you must also set the '--use-env' flag."
|
||||||
" you must also set the '--use-env' flag."
|
|
||||||
)
|
)
|
||||||
run(args)
|
run(args)
|
||||||
|
|
||||||
|
@ -39,7 +39,10 @@ _REMOTE_MODULE_PICKLED_ATTRIBUTES = (
|
|||||||
"module_rref",
|
"module_rref",
|
||||||
)
|
)
|
||||||
|
|
||||||
_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc]
|
_SerializedRemoteModule = collections.namedtuple( # type: ignore[misc]
|
||||||
|
"_SerializedRemoteModule",
|
||||||
|
_REMOTE_MODULE_PICKLED_ATTRIBUTES,
|
||||||
|
)
|
||||||
|
|
||||||
# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
|
# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
|
||||||
# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES
|
# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES
|
||||||
|
@ -26,15 +26,15 @@ sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)
|
|||||||
|
|
||||||
|
|
||||||
def get_arg_return_types_from_interface(module_interface):
|
def get_arg_return_types_from_interface(module_interface):
|
||||||
assert getattr(
|
assert getattr(module_interface, "__torch_script_interface__", False), (
|
||||||
module_interface, "__torch_script_interface__", False
|
"Expect a TorchScript class interface decorated by @torch.jit.interface."
|
||||||
), "Expect a TorchScript class interface decorated by @torch.jit.interface."
|
)
|
||||||
qualified_name = torch._jit_internal._qualified_name(module_interface)
|
qualified_name = torch._jit_internal._qualified_name(module_interface)
|
||||||
cu = torch.jit._state._python_cu
|
cu = torch.jit._state._python_cu
|
||||||
module_interface_c = cu.get_interface(qualified_name)
|
module_interface_c = cu.get_interface(qualified_name)
|
||||||
assert (
|
assert "forward" in module_interface_c.getMethodNames(), (
|
||||||
"forward" in module_interface_c.getMethodNames()
|
f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
|
||||||
), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
|
)
|
||||||
method_schema = module_interface_c.getMethod("forward")
|
method_schema = module_interface_c.getMethod("forward")
|
||||||
|
|
||||||
arg_str_list = []
|
arg_str_list = []
|
||||||
|
@ -5,6 +5,7 @@ optimizer locally on the workers where the parameters live. The distributed
|
|||||||
optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
|
optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
|
||||||
apply the gradients on each worker.
|
apply the gradients on each worker.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -44,10 +44,10 @@ def _apply_optimizer_in_backward(
|
|||||||
param_1 = next(params_generator)
|
param_1 = next(params_generator)
|
||||||
remainder_params = list(params_generator)
|
remainder_params = list(params_generator)
|
||||||
|
|
||||||
apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
|
apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02})
|
||||||
apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})
|
apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04})
|
||||||
|
|
||||||
model(...).sum().backward() # after backward, parameters will already
|
model(...).sum().backward() # after backward, parameters will already
|
||||||
# have their registered optimizer(s) applied.
|
# have their registered optimizer(s) applied.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -111,7 +111,7 @@ def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Opt
|
|||||||
List[torch.optim.Optimizer]: the in-backward optimizers.
|
List[torch.optim.Optimizer]: the in-backward optimizers.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01})
|
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
|
||||||
optims = _get_optimizers_in_backward(model)
|
optims = _get_optimizers_in_backward(model)
|
||||||
"""
|
"""
|
||||||
optims: list[torch.optim.Optimizer] = []
|
optims: list[torch.optim.Optimizer] = []
|
||||||
|
@ -147,12 +147,10 @@ class _NamedOptimizer(optim.Optimizer):
|
|||||||
return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
|
return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def step(self, closure: None = ...) -> None:
|
def step(self, closure: None = ...) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def step(self, closure: Callable[[], float]) -> float:
|
def step(self, closure: Callable[[], float]) -> float: ...
|
||||||
...
|
|
||||||
|
|
||||||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
r"""Zero Redundancy Optimizer."""
|
r"""Zero Redundancy Optimizer."""
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import copy
|
import copy
|
||||||
import enum
|
import enum
|
||||||
@ -262,9 +263,9 @@ class _OverlapInfo:
|
|||||||
meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
|
meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
|
||||||
in preparation for the next iteration.
|
in preparation for the next iteration.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert len(self.broadcast_handles) == self.num_bucket_assignments, (
|
||||||
len(self.broadcast_handles) == self.num_bucket_assignments
|
f"Missing at least one broadcast handle on rank {dist.get_rank()}"
|
||||||
), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
|
)
|
||||||
_ = [x.wait() for x in self.broadcast_handles]
|
_ = [x.wait() for x in self.broadcast_handles]
|
||||||
self.broadcast_handles.clear()
|
self.broadcast_handles.clear()
|
||||||
|
|
||||||
@ -909,9 +910,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
params_per_rank = overlap_info.params_per_rank
|
params_per_rank = overlap_info.params_per_rank
|
||||||
offsets = overlap_info.offsets
|
offsets = overlap_info.offsets
|
||||||
|
|
||||||
self._bucket_assignments_per_rank_cache[assigned_rank][
|
self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = (
|
||||||
bucket_index
|
_DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
|
||||||
] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
|
)
|
||||||
if self.global_rank == assigned_rank:
|
if self.global_rank == assigned_rank:
|
||||||
offsets[bucket_index] = len(params_per_rank[assigned_rank])
|
offsets[bucket_index] = len(params_per_rank[assigned_rank])
|
||||||
params_per_rank[assigned_rank].extend(bucket_params)
|
params_per_rank[assigned_rank].extend(bucket_params)
|
||||||
@ -927,9 +928,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
mapping bucket indices to :class:`_DDPBucketAssignment` s for each
|
mapping bucket indices to :class:`_DDPBucketAssignment` s for each
|
||||||
rank.
|
rank.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert self._overlap_with_ddp, (
|
||||||
self._overlap_with_ddp
|
"`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`"
|
||||||
), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`"
|
)
|
||||||
if len(self._bucket_assignments_per_rank_cache) > 0:
|
if len(self._bucket_assignments_per_rank_cache) > 0:
|
||||||
return self._bucket_assignments_per_rank_cache
|
return self._bucket_assignments_per_rank_cache
|
||||||
|
|
||||||
@ -1076,9 +1077,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
"Specifying `gradients` should not "
|
"Specifying `gradients` should not "
|
||||||
"be used when `overlap_with_ddp=False`"
|
"be used when `overlap_with_ddp=False`"
|
||||||
)
|
)
|
||||||
assert (
|
assert closure is None, (
|
||||||
closure is None
|
"`closure` is not supported when using a local functional optimizer"
|
||||||
), "`closure` is not supported when using a local functional optimizer"
|
)
|
||||||
loss = self.optim.step(gradients=gradients)
|
loss = self.optim.step(gradients=gradients)
|
||||||
|
|
||||||
# Sync any updated attributes in the local optimizer to the exposed
|
# Sync any updated attributes in the local optimizer to the exposed
|
||||||
@ -1221,9 +1222,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
for rank, local_state_dict in enumerate(self._all_state_dicts):
|
for rank, local_state_dict in enumerate(self._all_state_dicts):
|
||||||
local_param_groups = local_state_dict["param_groups"]
|
local_param_groups = local_state_dict["param_groups"]
|
||||||
global_param_groups = self._partition_parameters()[rank]
|
global_param_groups = self._partition_parameters()[rank]
|
||||||
assert len(local_param_groups) == len(
|
assert len(local_param_groups) == len(global_param_groups), (
|
||||||
global_param_groups
|
"Mismatch between number of local and global parameter groups"
|
||||||
), "Mismatch between number of local and global parameter groups"
|
)
|
||||||
|
|
||||||
for local_param_group, global_param_group in zip(
|
for local_param_group, global_param_group in zip(
|
||||||
local_param_groups, global_param_groups
|
local_param_groups, global_param_groups
|
||||||
@ -1233,9 +1234,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
local_param_indices = local_param_group["params"]
|
local_param_indices = local_param_group["params"]
|
||||||
global_params = global_param_group["params"]
|
global_params = global_param_group["params"]
|
||||||
|
|
||||||
assert len(local_param_indices) == len(
|
assert len(local_param_indices) == len(global_params), (
|
||||||
global_params
|
"Mismatch between number of local and global parameters in parameter group"
|
||||||
), "Mismatch between number of local and global parameters in parameter group"
|
)
|
||||||
for local_param_index, global_param in zip(
|
for local_param_index, global_param in zip(
|
||||||
local_param_indices, global_params
|
local_param_indices, global_params
|
||||||
):
|
):
|
||||||
@ -1268,9 +1269,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
dst_param_groups (list[dict]): parameter groups giving the
|
dst_param_groups (list[dict]): parameter groups giving the
|
||||||
attribute settings to set.
|
attribute settings to set.
|
||||||
"""
|
"""
|
||||||
assert len(src_param_groups) == len(
|
assert len(src_param_groups) == len(dst_param_groups), (
|
||||||
dst_param_groups
|
"Mismatch between number of source and destination parameter groups"
|
||||||
), "Mismatch between number of source and destination parameter groups"
|
)
|
||||||
for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
|
for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
|
||||||
# Sync all attributes except the parameters
|
# Sync all attributes except the parameters
|
||||||
for attr in filter(lambda x: x != "params", src_param_group.keys()):
|
for attr in filter(lambda x: x != "params", src_param_group.keys()):
|
||||||
@ -1479,9 +1480,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
|
|
||||||
The local optimizer is saved in ``self.optim``.
|
The local optimizer is saved in ``self.optim``.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert self._optim_constructor is not None, (
|
||||||
self._optim_constructor is not None
|
"The local optimizer class has not been set"
|
||||||
), "The local optimizer class has not been set"
|
)
|
||||||
|
|
||||||
param_groups = self._partition_parameters()[self.rank]
|
param_groups = self._partition_parameters()[self.rank]
|
||||||
# `overlap_with_ddp=True` requires a local functional optimizer
|
# `overlap_with_ddp=True` requires a local functional optimizer
|
||||||
@ -1508,7 +1509,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
"error due to an empty parameter list",
|
"error due to an empty parameter list",
|
||||||
self._optim_constructor,
|
self._optim_constructor,
|
||||||
)
|
)
|
||||||
self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef]
|
self.optim: Any = self._optim_constructor(
|
||||||
|
params, **self._optim_defaults
|
||||||
|
) # type: ignore[no-redef]
|
||||||
|
|
||||||
# Log information about the DDP and ZeRO bucketing
|
# Log information about the DDP and ZeRO bucketing
|
||||||
if dist.get_debug_level() != dist.DebugLevel.OFF:
|
if dist.get_debug_level() != dist.DebugLevel.OFF:
|
||||||
@ -1531,7 +1534,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||||||
else:
|
else:
|
||||||
# NOTE: Passing `param_groups` into the local optimizer constructor
|
# NOTE: Passing `param_groups` into the local optimizer constructor
|
||||||
# bypasses the empty parameter list check
|
# bypasses the empty parameter list check
|
||||||
self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef]
|
self.optim: Optimizer = self._optim_constructor(
|
||||||
|
param_groups, **self._optim_defaults
|
||||||
|
) # type: ignore[no-redef]
|
||||||
|
|
||||||
# TODO: Manually add `self.param_groups` if using a functional
|
# TODO: Manually add `self.param_groups` if using a functional
|
||||||
# optimizer; remove this if/when the functional optimizers support
|
# optimizer; remove this if/when the functional optimizers support
|
||||||
|
@ -123,12 +123,11 @@ def _insert_stage_symbolic_backward(
|
|||||||
# getitem calls. If we have a target other than getitem in this
|
# getitem calls. If we have a target other than getitem in this
|
||||||
# (forward-only) code, there is a bug.
|
# (forward-only) code, there is a bug.
|
||||||
assert node.target == operator.getitem, (
|
assert node.target == operator.getitem, (
|
||||||
"Found non-getitem call in forward pass. "
|
"Found non-getitem call in forward pass. Please report a bug to PiPPy"
|
||||||
"Please report a bug to PiPPy"
|
)
|
||||||
|
assert len(node.args) == 2, (
|
||||||
|
"Found malformed getitem call. Please report a bug to PiPPy"
|
||||||
)
|
)
|
||||||
assert (
|
|
||||||
len(node.args) == 2
|
|
||||||
), "Found malformed getitem call. Please report a bug to PiPPy"
|
|
||||||
indexed_value, node_idx = tuple(node.args)
|
indexed_value, node_idx = tuple(node.args)
|
||||||
|
|
||||||
# indexed_value is a collection that we are indexing into. It could
|
# indexed_value is a collection that we are indexing into. It could
|
||||||
@ -249,8 +248,8 @@ class LossWrapper(torch.nn.Module):
|
|||||||
targets value into the loss function, and get and return the loss value, which will
|
targets value into the loss function, and get and return the loss value, which will
|
||||||
be backpropagated by PiPPy. The above class would then be instantiated like::
|
be backpropagated by PiPPy. The above class would then be instantiated like::
|
||||||
|
|
||||||
model = ... # instantiate the model
|
model = ... # instantiate the model
|
||||||
loss_fn = torch.nn.MSELoss() # for the sake of demonstration
|
loss_fn = torch.nn.MSELoss() # for the sake of demonstration
|
||||||
|
|
||||||
wrapper = MyModelWrapper(model, loss_fn)
|
wrapper = MyModelWrapper(model, loss_fn)
|
||||||
pipe = Pipe.from_tracing(wrapper, ...)
|
pipe = Pipe.from_tracing(wrapper, ...)
|
||||||
@ -818,9 +817,9 @@ class Pipe(torch.nn.Module):
|
|||||||
|
|
||||||
# Get submodule
|
# Get submodule
|
||||||
callee = root.get_submodule(callee_name)
|
callee = root.get_submodule(callee_name)
|
||||||
assert not hasattr(
|
assert not hasattr(callee, param_fqn), (
|
||||||
callee, param_fqn
|
f"Module {callee_name} already has a parameter named {param_fqn}"
|
||||||
), f"Module {callee_name} already has a parameter named {param_fqn}"
|
)
|
||||||
|
|
||||||
# Assign the parameter to the submodule
|
# Assign the parameter to the submodule
|
||||||
if is_buffer:
|
if is_buffer:
|
||||||
@ -979,7 +978,7 @@ class Pipe(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
logger.debug("Pipeline is in inference mode, backward pass not generated")
|
logger.debug("Pipeline is in inference mode, backward pass not generated")
|
||||||
|
|
||||||
logger.debug("Full pipe model:\n" f"{split}") # noqa: G004
|
logger.debug(f"Full pipe model:\n{split}") # noqa: G004
|
||||||
|
|
||||||
return Pipe(
|
return Pipe(
|
||||||
split,
|
split,
|
||||||
@ -1184,7 +1183,7 @@ def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
|
|||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
f"Specified target {qualname} referenced "
|
f"Specified target {qualname} referenced "
|
||||||
f'nonexistent module {".".join(atoms[: i + 1])}'
|
f"nonexistent module {'.'.join(atoms[: i + 1])}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
mod_to_wrap = getattr(predecessor_module, atoms[-1])
|
mod_to_wrap = getattr(predecessor_module, atoms[-1])
|
||||||
|
@ -306,17 +306,17 @@ def stage_backward(
|
|||||||
if isinstance(output_val, torch.Tensor):
|
if isinstance(output_val, torch.Tensor):
|
||||||
if not output_val.requires_grad and output_val.grad_fn is None:
|
if not output_val.requires_grad and output_val.grad_fn is None:
|
||||||
return
|
return
|
||||||
assert isinstance(
|
assert isinstance(grad_val, (torch.Tensor, type(None))), (
|
||||||
grad_val, (torch.Tensor, type(None))
|
f"Expected Tensor or None gradient but got {type(grad_val)}"
|
||||||
), f"Expected Tensor or None gradient but got {type(grad_val)}"
|
)
|
||||||
stage_output_tensors.append(output_val)
|
stage_output_tensors.append(output_val)
|
||||||
output_grad_tensors.append(grad_val)
|
output_grad_tensors.append(grad_val)
|
||||||
elif isinstance(output_val, (tuple, list)):
|
elif isinstance(output_val, (tuple, list)):
|
||||||
if grad_val is None:
|
if grad_val is None:
|
||||||
return
|
return
|
||||||
assert isinstance(
|
assert isinstance(grad_val, (tuple, list)), (
|
||||||
grad_val, (tuple, list)
|
f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
|
||||||
), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
|
)
|
||||||
assert len(output_val) == len(grad_val)
|
assert len(output_val) == len(grad_val)
|
||||||
for ov, gv in zip(output_val, grad_val):
|
for ov, gv in zip(output_val, grad_val):
|
||||||
extract_tensors_with_grads(
|
extract_tensors_with_grads(
|
||||||
@ -350,7 +350,8 @@ def stage_backward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
torch.autograd.backward(
|
torch.autograd.backward(
|
||||||
stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type]
|
stage_output_tensors,
|
||||||
|
grad_tensors=output_grad_tensors, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract gradients wrt the input values
|
# Extract gradients wrt the input values
|
||||||
|
@ -140,9 +140,9 @@ def _shard_dict_of_args(
|
|||||||
real_num_chunks = num_chunks
|
real_num_chunks = num_chunks
|
||||||
first_tensor = True
|
first_tensor = True
|
||||||
|
|
||||||
assert len(args_dict) == len(
|
assert len(args_dict) == len(args_chunk_spec), (
|
||||||
args_chunk_spec
|
f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
|
||||||
), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
|
)
|
||||||
|
|
||||||
for arg_key, arg in args_dict.items():
|
for arg_key, arg in args_dict.items():
|
||||||
flat, spec = tree_flatten(arg)
|
flat, spec = tree_flatten(arg)
|
||||||
|
@ -706,7 +706,9 @@ class Schedule1F1B(PipelineScheduleSingle):
|
|||||||
recv_work.wait()
|
recv_work.wait()
|
||||||
|
|
||||||
# Compute
|
# Compute
|
||||||
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
|
output = self._stage.forward_one_chunk(
|
||||||
|
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
|
||||||
|
) # type: ignore[index]
|
||||||
|
|
||||||
# Clear previous chunk's forward sends (hopefully they have well
|
# Clear previous chunk's forward sends (hopefully they have well
|
||||||
# finished, otherwise, we are heavily communication bound, in which
|
# finished, otherwise, we are heavily communication bound, in which
|
||||||
@ -762,7 +764,9 @@ class Schedule1F1B(PipelineScheduleSingle):
|
|||||||
fuse_work.wait()
|
fuse_work.wait()
|
||||||
|
|
||||||
# Now do the fwd
|
# Now do the fwd
|
||||||
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
|
output = self._stage.forward_one_chunk(
|
||||||
|
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
|
||||||
|
) # type: ignore[index]
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
|
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
|
||||||
@ -992,9 +996,9 @@ def _add_send_recv(
|
|||||||
progress = False
|
progress = False
|
||||||
# go in order of ranks even if dict keys aren't ordered
|
# go in order of ranks even if dict keys aren't ordered
|
||||||
for rank in sorted(compute_actions):
|
for rank in sorted(compute_actions):
|
||||||
assert (
|
assert len(compute_actions[rank]) > 0, (
|
||||||
len(compute_actions[rank]) > 0
|
f"{rank=}, {len(compute_actions[rank])=}"
|
||||||
), f"{rank=}, {len(compute_actions[rank])=}"
|
)
|
||||||
action = compute_actions[rank][0]
|
action = compute_actions[rank][0]
|
||||||
|
|
||||||
if not _ready_to_schedule(action, prev_actions[rank]):
|
if not _ready_to_schedule(action, prev_actions[rank]):
|
||||||
@ -1026,9 +1030,9 @@ def _validate_schedule(
|
|||||||
num_stages: int,
|
num_stages: int,
|
||||||
num_microbatches: int,
|
num_microbatches: int,
|
||||||
) -> dict[int, int]:
|
) -> dict[int, int]:
|
||||||
assert (
|
assert len(actions) == pp_group_size, (
|
||||||
len(actions) == pp_group_size
|
f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
|
||||||
), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
|
)
|
||||||
for rank in range(pp_group_size):
|
for rank in range(pp_group_size):
|
||||||
assert rank in actions, f"Schedule is missing actions for rank {rank}"
|
assert rank in actions, f"Schedule is missing actions for rank {rank}"
|
||||||
|
|
||||||
@ -1048,36 +1052,36 @@ def _validate_schedule(
|
|||||||
for action in actions[rank]:
|
for action in actions[rank]:
|
||||||
if action is None:
|
if action is None:
|
||||||
continue
|
continue
|
||||||
assert isinstance(
|
assert isinstance(action, _Action), (
|
||||||
action, _Action
|
f"Got an invalid action: {action}, expected instance of _Action"
|
||||||
), f"Got an invalid action: {action}, expected instance of _Action"
|
)
|
||||||
s_id = action.stage_index
|
s_id = action.stage_index
|
||||||
ctype = action.computation_type
|
ctype = action.computation_type
|
||||||
mb_id = action.microbatch_index
|
mb_id = action.microbatch_index
|
||||||
if ctype == F:
|
if ctype == F:
|
||||||
stage_actions[s_id][F].add(mb_id)
|
stage_actions[s_id][F].add(mb_id)
|
||||||
elif ctype == B:
|
elif ctype == B:
|
||||||
assert (
|
assert mb_id in stage_actions[s_id][F], (
|
||||||
mb_id in stage_actions[s_id][F]
|
f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||||
), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
|
)
|
||||||
stage_actions[s_id][B].add(mb_id)
|
stage_actions[s_id][B].add(mb_id)
|
||||||
elif ctype == I:
|
elif ctype == I:
|
||||||
assert (
|
assert mb_id in stage_actions[s_id][F], (
|
||||||
mb_id in stage_actions[s_id][F]
|
f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||||
), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
|
)
|
||||||
stage_actions[s_id][I].add(mb_id)
|
stage_actions[s_id][I].add(mb_id)
|
||||||
elif ctype == W:
|
elif ctype == W:
|
||||||
assert (
|
assert mb_id in stage_actions[s_id][I], (
|
||||||
mb_id in stage_actions[s_id][I]
|
f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input"
|
||||||
), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input"
|
)
|
||||||
stage_actions[s_id][W].add(mb_id)
|
stage_actions[s_id][W].add(mb_id)
|
||||||
if s_id not in stage_index_to_rank_mapping:
|
if s_id not in stage_index_to_rank_mapping:
|
||||||
stage_index_to_rank_mapping[s_id] = rank
|
stage_index_to_rank_mapping[s_id] = rank
|
||||||
else:
|
else:
|
||||||
existing_rank = stage_index_to_rank_mapping[s_id]
|
existing_rank = stage_index_to_rank_mapping[s_id]
|
||||||
assert (
|
assert rank == existing_rank, (
|
||||||
rank == existing_rank
|
f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
|
||||||
), f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
|
)
|
||||||
|
|
||||||
for s_id in stage_actions:
|
for s_id in stage_actions:
|
||||||
f_mb = len(stage_actions[s_id][F])
|
f_mb = len(stage_actions[s_id][F])
|
||||||
@ -1085,14 +1089,14 @@ def _validate_schedule(
|
|||||||
i_mb = len(stage_actions[s_id][I])
|
i_mb = len(stage_actions[s_id][I])
|
||||||
w_mb = len(stage_actions[s_id][W])
|
w_mb = len(stage_actions[s_id][W])
|
||||||
|
|
||||||
assert (
|
assert f_mb == num_microbatches, (
|
||||||
f_mb == num_microbatches
|
f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
|
||||||
), f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
|
)
|
||||||
|
|
||||||
assert (
|
assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, (
|
||||||
b_mb + (i_mb + w_mb) // 2 == num_microbatches
|
f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
|
||||||
), f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
|
|
||||||
but got B={b_mb}, I={i_mb}, W={w_mb}"
|
but got B={b_mb}, I={i_mb}, W={w_mb}"
|
||||||
|
)
|
||||||
return stage_index_to_rank_mapping
|
return stage_index_to_rank_mapping
|
||||||
|
|
||||||
|
|
||||||
@ -1289,9 +1293,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||||||
computation_type = action.computation_type
|
computation_type = action.computation_type
|
||||||
mb_index = action.microbatch_index
|
mb_index = action.microbatch_index
|
||||||
stage_index = action.stage_index
|
stage_index = action.stage_index
|
||||||
assert (
|
assert mb_index is not None, (
|
||||||
mb_index is not None
|
"All currently supported action types require valid microbatch_index"
|
||||||
), "All currently supported action types require valid microbatch_index"
|
)
|
||||||
if computation_type == _ComputationType.FORWARD:
|
if computation_type == _ComputationType.FORWARD:
|
||||||
# perform forward computation
|
# perform forward computation
|
||||||
stage = stage_index_to_stage[stage_index]
|
stage = stage_index_to_stage[stage_index]
|
||||||
@ -1362,9 +1366,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||||||
computation_type = prev_rank_action.computation_type
|
computation_type = prev_rank_action.computation_type
|
||||||
mb_index = prev_rank_action.microbatch_index
|
mb_index = prev_rank_action.microbatch_index
|
||||||
stage_index = prev_rank_action.stage_index
|
stage_index = prev_rank_action.stage_index
|
||||||
assert (
|
assert mb_index is not None, (
|
||||||
mb_index is not None
|
"All currently supported action types require valid microbatch_index"
|
||||||
), "All currently supported action types require valid microbatch_index"
|
)
|
||||||
# Only handle sends for the forward from a previous rank
|
# Only handle sends for the forward from a previous rank
|
||||||
if computation_type == _ComputationType.FORWARD:
|
if computation_type == _ComputationType.FORWARD:
|
||||||
# If not the last stage, then receive fwd activations
|
# If not the last stage, then receive fwd activations
|
||||||
@ -1393,9 +1397,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||||||
computation_type = next_rank_action.computation_type
|
computation_type = next_rank_action.computation_type
|
||||||
mb_index = next_rank_action.microbatch_index
|
mb_index = next_rank_action.microbatch_index
|
||||||
stage_index = next_rank_action.stage_index
|
stage_index = next_rank_action.stage_index
|
||||||
assert (
|
assert mb_index is not None, (
|
||||||
mb_index is not None
|
"All currently supported action types require valid microbatch_index"
|
||||||
), "All currently supported action types require valid microbatch_index"
|
)
|
||||||
# Only handle receives for the backwards from a next rank
|
# Only handle receives for the backwards from a next rank
|
||||||
if computation_type in (FORWARD, BACKWARD_WEIGHT):
|
if computation_type in (FORWARD, BACKWARD_WEIGHT):
|
||||||
# Next rank doing forward or weight update has no influence for the current rank backward recv
|
# Next rank doing forward or weight update has no influence for the current rank backward recv
|
||||||
@ -1503,9 +1507,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||||||
"""Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
|
"""Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
|
||||||
# TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
|
# TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
|
||||||
# that it does not exist if it was created from a compute_comms schedule.
|
# that it does not exist if it was created from a compute_comms schedule.
|
||||||
assert (
|
assert self.pipeline_order_with_comms is not None, (
|
||||||
self.pipeline_order_with_comms is not None
|
"Must initialize compute_comms schedule before dump_csv"
|
||||||
), "Must initialize compute_comms schedule before dump_csv"
|
)
|
||||||
with open(filename, "w", newline="") as csvfile:
|
with open(filename, "w", newline="") as csvfile:
|
||||||
writer = csv.writer(csvfile)
|
writer = csv.writer(csvfile)
|
||||||
for rank in self.pipeline_order_with_comms:
|
for rank in self.pipeline_order_with_comms:
|
||||||
@ -1541,9 +1545,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||||||
stage.stage_index: stage for stage in self._stages
|
stage.stage_index: stage for stage in self._stages
|
||||||
}
|
}
|
||||||
|
|
||||||
assert (
|
assert self.pipeline_order_with_comms is not None, (
|
||||||
self.pipeline_order_with_comms is not None
|
"Must call _load_actions() before calling _step_microbatches()"
|
||||||
), "Must call _load_actions() before calling _step_microbatches()"
|
)
|
||||||
|
|
||||||
# recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
|
# recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
|
||||||
bwd_recv_ops: dict[tuple[int, int], Work] = {}
|
bwd_recv_ops: dict[tuple[int, int], Work] = {}
|
||||||
@ -1562,9 +1566,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||||||
unshard_ops[stage_idx].wait()
|
unshard_ops[stage_idx].wait()
|
||||||
del unshard_ops[stage_idx]
|
del unshard_ops[stage_idx]
|
||||||
unsharded_stages.add(stage_idx)
|
unsharded_stages.add(stage_idx)
|
||||||
assert (
|
assert stage_idx in unsharded_stages, (
|
||||||
stage_idx in unsharded_stages
|
f"Attempted to compute on sharded {stage_idx=}"
|
||||||
), f"Attempted to compute on sharded {stage_idx=}"
|
)
|
||||||
|
|
||||||
# count either full_backward or backward_weight together, to determine when to sync DP grads
|
# count either full_backward or backward_weight together, to determine when to sync DP grads
|
||||||
backward_counter: Counter[int] = Counter()
|
backward_counter: Counter[int] = Counter()
|
||||||
@ -1606,7 +1610,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||||||
assert (
|
assert (
|
||||||
stage_idx,
|
stage_idx,
|
||||||
mb_index,
|
mb_index,
|
||||||
) not in fwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing forward"
|
) not in fwd_recv_ops, (
|
||||||
|
"Recv twice for {stage_idx=} {mb_index=} without executing forward"
|
||||||
|
)
|
||||||
fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
|
fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
|
||||||
stage.get_fwd_recv_ops(mb_index)
|
stage.get_fwd_recv_ops(mb_index)
|
||||||
)
|
)
|
||||||
@ -1614,7 +1620,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||||||
assert (
|
assert (
|
||||||
stage_idx,
|
stage_idx,
|
||||||
mb_index,
|
mb_index,
|
||||||
) not in bwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing backward"
|
) not in bwd_recv_ops, (
|
||||||
|
"Recv twice for {stage_idx=} {mb_index=} without executing backward"
|
||||||
|
)
|
||||||
bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
|
bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
|
||||||
stage.get_bwd_recv_ops(mb_index)
|
stage.get_bwd_recv_ops(mb_index)
|
||||||
)
|
)
|
||||||
@ -1627,12 +1635,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||||||
unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator]
|
unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator]
|
||||||
elif comp_type == RESHARD:
|
elif comp_type == RESHARD:
|
||||||
if stage_uses_fsdp:
|
if stage_uses_fsdp:
|
||||||
assert (
|
assert stage_idx in unsharded_stages, (
|
||||||
stage_idx in unsharded_stages
|
f"Resharding {stage_idx=} without unsharding"
|
||||||
), f"Resharding {stage_idx=} without unsharding"
|
)
|
||||||
assert (
|
assert stage_idx not in unshard_ops, (
|
||||||
stage_idx not in unshard_ops
|
f"Resharding {stage_idx=} before finishing unshard"
|
||||||
), f"Resharding {stage_idx=} before finishing unshard"
|
)
|
||||||
stage.submod.reshard() # type: ignore[operator]
|
stage.submod.reshard() # type: ignore[operator]
|
||||||
elif comp_type == FORWARD:
|
elif comp_type == FORWARD:
|
||||||
if stage_uses_fsdp:
|
if stage_uses_fsdp:
|
||||||
@ -1739,7 +1747,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||||||
)
|
)
|
||||||
# TODO(whc) what is the best practice for printing a multiline log?
|
# TODO(whc) what is the best practice for printing a multiline log?
|
||||||
# logger will split it into multiple log lines, but this makes it hard to read (too wide)
|
# logger will split it into multiple log lines, but this makes it hard to read (too wide)
|
||||||
print(_format_pipeline_order(self.pipeline_order_with_comms, error_step_number=time_step)) # type: ignore[arg-type]
|
print(
|
||||||
|
_format_pipeline_order(
|
||||||
|
self.pipeline_order_with_comms, # type: ignore[arg-type]
|
||||||
|
error_step_number=time_step,
|
||||||
|
)
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
|
# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
|
||||||
|
@ -243,16 +243,16 @@ class _PipelineStageBase(ABC):
|
|||||||
configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
|
configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
|
||||||
which could show up as hangs, silent corruption, or other errors.
|
which could show up as hangs, silent corruption, or other errors.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert self._outputs_meta is None, (
|
||||||
self._outputs_meta is None
|
"Attempting to reconfigure output_meta, which is not supported"
|
||||||
), "Attempting to reconfigure output_meta, which is not supported"
|
)
|
||||||
self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment]
|
self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment]
|
||||||
|
|
||||||
def get_outputs_meta(self) -> tuple[torch.Tensor, ...]:
|
def get_outputs_meta(self) -> tuple[torch.Tensor, ...]:
|
||||||
"""Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
|
"""Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
|
||||||
assert (
|
assert self._outputs_meta is not None, (
|
||||||
self._outputs_meta is not None
|
"Attempted to get_outputs_meta() without configuring output meta"
|
||||||
), "Attempted to get_outputs_meta() without configuring output meta"
|
)
|
||||||
return self._outputs_meta
|
return self._outputs_meta
|
||||||
|
|
||||||
def _create_grad_send_info(
|
def _create_grad_send_info(
|
||||||
@ -358,12 +358,12 @@ class _PipelineStageBase(ABC):
|
|||||||
prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs)
|
prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs)
|
||||||
|
|
||||||
for info, tensor in zip(recv_infos, prev_stage_outputs):
|
for info, tensor in zip(recv_infos, prev_stage_outputs):
|
||||||
assert isinstance(
|
assert isinstance(tensor, torch.Tensor), (
|
||||||
tensor, torch.Tensor
|
f"expected tensor values as outputs from prev stage, got {type(tensor)}"
|
||||||
), f"expected tensor values as outputs from prev stage, got {type(tensor)}"
|
)
|
||||||
assert isinstance(
|
assert isinstance(info, _RecvInfo), (
|
||||||
info, _RecvInfo
|
"set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo"
|
||||||
), "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo"
|
)
|
||||||
|
|
||||||
# We don't need to do a data copy here, since we can directly pass the activation tensor reference from
|
# We don't need to do a data copy here, since we can directly pass the activation tensor reference from
|
||||||
# one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve
|
# one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve
|
||||||
@ -376,9 +376,9 @@ class _PipelineStageBase(ABC):
|
|||||||
"""
|
"""
|
||||||
Returns the input grad tensors for this stage, which correspond to the stage inputs during forward.
|
Returns the input grad tensors for this stage, which correspond to the stage inputs during forward.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert self.has_backward, (
|
||||||
self.has_backward
|
"can't steal_bwd_input if this stage doesn't have backward"
|
||||||
), "can't steal_bwd_input if this stage doesn't have backward"
|
)
|
||||||
assert not self.is_first, "can't get bwd output if this stage is first"
|
assert not self.is_first, "can't get bwd output if this stage is first"
|
||||||
|
|
||||||
self._check_chunk_id(mb_index)
|
self._check_chunk_id(mb_index)
|
||||||
@ -391,22 +391,22 @@ class _PipelineStageBase(ABC):
|
|||||||
Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv.
|
Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv.
|
||||||
Does not detach or set '_requires_grad'.
|
Does not detach or set '_requires_grad'.
|
||||||
"""
|
"""
|
||||||
assert isinstance(
|
assert isinstance(next_stage_bwd_outputs, tuple), (
|
||||||
next_stage_bwd_outputs, tuple
|
f"Expected tuple, got {type(next_stage_bwd_outputs)}"
|
||||||
), f"Expected tuple, got {type(next_stage_bwd_outputs)}"
|
)
|
||||||
|
|
||||||
assert (
|
assert self.has_backward, (
|
||||||
self.has_backward
|
"can't set bwd input if this stage doesn't have backward"
|
||||||
), "can't set bwd input if this stage doesn't have backward"
|
)
|
||||||
assert not self.is_last, "can't set bwd input if this stage is last"
|
assert not self.is_last, "can't set bwd input if this stage is last"
|
||||||
recv_infos = self.grad_recv_info[mb_index]
|
recv_infos = self.grad_recv_info[mb_index]
|
||||||
for info, tensor in zip(recv_infos, next_stage_bwd_outputs):
|
for info, tensor in zip(recv_infos, next_stage_bwd_outputs):
|
||||||
assert isinstance(
|
assert isinstance(tensor, torch.Tensor), (
|
||||||
tensor, torch.Tensor
|
f"expected tensor values as outputs from prev stage, got {type(tensor)}"
|
||||||
), f"expected tensor values as outputs from prev stage, got {type(tensor)}"
|
)
|
||||||
assert isinstance(
|
assert isinstance(info, _RecvInfo), (
|
||||||
info, _RecvInfo
|
f"Expected a recv info, got {type(info)}"
|
||||||
), f"Expected a recv info, got {type(info)}"
|
)
|
||||||
info.buffer = tensor
|
info.buffer = tensor
|
||||||
|
|
||||||
def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
|
def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
|
||||||
@ -1053,9 +1053,9 @@ class _PipelineStage(_PipelineStageBase):
|
|||||||
# If the input is a getitem, we need to go deeper
|
# If the input is a getitem, we need to go deeper
|
||||||
arg_node = arg_node.args[0]
|
arg_node = arg_node.args[0]
|
||||||
|
|
||||||
assert (
|
assert arg_node.op == "call_module", (
|
||||||
arg_node.op == "call_module"
|
f"Expecting call_module, got {arg_node.op}"
|
||||||
), f"Expecting call_module, got {arg_node.op}"
|
)
|
||||||
src_stage = self.get_stage_index_of_submod(arg_node.name)
|
src_stage = self.get_stage_index_of_submod(arg_node.name)
|
||||||
|
|
||||||
# Create a receive buffer for this placeholder
|
# Create a receive buffer for this placeholder
|
||||||
@ -1081,7 +1081,8 @@ class _PipelineStage(_PipelineStageBase):
|
|||||||
args_recv_info: list[InputInfo] = []
|
args_recv_info: list[InputInfo] = []
|
||||||
# Filter out placeholder nodes from `self.submod` (a GraphModule)
|
# Filter out placeholder nodes from `self.submod` (a GraphModule)
|
||||||
placeholders = filter( # type: ignore[var-annotated]
|
placeholders = filter( # type: ignore[var-annotated]
|
||||||
lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr]
|
lambda node: node.op == "placeholder", # type: ignore[arg-type]
|
||||||
|
self.submod.graph.nodes, # type: ignore[arg-type,union-attr]
|
||||||
)
|
)
|
||||||
# `placeholders` are nodes internal to submod.
|
# `placeholders` are nodes internal to submod.
|
||||||
# `self.node.args` are dependency nodes in the outer graph.
|
# `self.node.args` are dependency nodes in the outer graph.
|
||||||
@ -1300,9 +1301,9 @@ class PipelineStage(_PipelineStageBase):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Failed to perform pipeline shape inference- are your inputs on the same device as your module?"
|
"Failed to perform pipeline shape inference- are your inputs on the same device as your module?"
|
||||||
) from e
|
) from e
|
||||||
assert (
|
assert output_args is not None, (
|
||||||
output_args is not None
|
"If passing input_args, also pass output_args to override shape inference"
|
||||||
), "If passing input_args, also pass output_args to override shape inference"
|
)
|
||||||
self._configure_outputs_meta(
|
self._configure_outputs_meta(
|
||||||
(output_args,) if isinstance(output_args, torch.Tensor) else output_args
|
(output_args,) if isinstance(output_args, torch.Tensor) else output_args
|
||||||
)
|
)
|
||||||
@ -1346,9 +1347,9 @@ class PipelineStage(_PipelineStageBase):
|
|||||||
)
|
)
|
||||||
args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args)
|
args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert len(args) == 0, (
|
||||||
len(args) == 0
|
"Can't supply input args for shape inference on non-first stage"
|
||||||
), "Can't supply input args for shape inference on non-first stage"
|
)
|
||||||
objects = [None]
|
objects = [None]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Shape inference: stage %s receiving from stage %s",
|
"Shape inference: stage %s receiving from stage %s",
|
||||||
|
@ -80,9 +80,9 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
|
|||||||
world_size = world_size_opt
|
world_size = world_size_opt
|
||||||
if rank != -1 or world_size != -1 or world_size_opt is None:
|
if rank != -1 or world_size != -1 or world_size_opt is None:
|
||||||
query_dict = _query_to_dict(result.query)
|
query_dict = _query_to_dict(result.query)
|
||||||
assert (
|
assert "rank" not in query_dict and "world_size" not in query_dict, (
|
||||||
"rank" not in query_dict and "world_size" not in query_dict
|
f"The url: {url} has node-specific arguments(rank, world_size) already."
|
||||||
), f"The url: {url} has node-specific arguments(rank, world_size) already."
|
)
|
||||||
if rank != -1:
|
if rank != -1:
|
||||||
query_dict["rank"] = str(rank)
|
query_dict["rank"] = str(rank)
|
||||||
if world_size != -1 or world_size_opt is None:
|
if world_size != -1 or world_size_opt is None:
|
||||||
|
@ -137,13 +137,13 @@ def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
|
|||||||
with _all_gather_dict_lock:
|
with _all_gather_dict_lock:
|
||||||
if not worker_names:
|
if not worker_names:
|
||||||
worker_names = _ALL_WORKER_NAMES
|
worker_names = _ALL_WORKER_NAMES
|
||||||
assert (
|
assert worker_name in worker_names, (
|
||||||
worker_name in worker_names
|
f"{worker_name} is not expected by leader."
|
||||||
), f"{worker_name} is not expected by leader."
|
)
|
||||||
states = _all_gather_sequence_id_to_states[sequence_id]
|
states = _all_gather_sequence_id_to_states[sequence_id]
|
||||||
assert (
|
assert worker_name not in states.gathered_objects, (
|
||||||
worker_name not in states.gathered_objects
|
f"{worker_name} reported intent sequence id {sequence_id} twice. "
|
||||||
), f"{worker_name} reported intent sequence id {sequence_id} twice. "
|
)
|
||||||
states.gathered_objects[worker_name] = obj
|
states.gathered_objects[worker_name] = obj
|
||||||
if worker_names == set(states.gathered_objects.keys()):
|
if worker_names == set(states.gathered_objects.keys()):
|
||||||
states.proceed_signal.set()
|
states.proceed_signal.set()
|
||||||
@ -153,9 +153,9 @@ def _broadcast_to_followers(sequence_id, objects_map):
|
|||||||
with _all_gather_dict_lock:
|
with _all_gather_dict_lock:
|
||||||
states = _all_gather_sequence_id_to_states[sequence_id]
|
states = _all_gather_sequence_id_to_states[sequence_id]
|
||||||
|
|
||||||
assert (
|
assert not states.proceed_signal.is_set(), (
|
||||||
not states.proceed_signal.is_set()
|
f"Termination signal sequence id {sequence_id} got set twice."
|
||||||
), f"Termination signal sequence id {sequence_id} got set twice."
|
)
|
||||||
states.gathered_objects = objects_map
|
states.gathered_objects = objects_map
|
||||||
states.proceed_signal.set()
|
states.proceed_signal.set()
|
||||||
|
|
||||||
@ -202,9 +202,9 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
|
|||||||
function blocks until all workers have received the gathered results.
|
function blocks until all workers have received the gathered results.
|
||||||
"""
|
"""
|
||||||
if not worker_names:
|
if not worker_names:
|
||||||
assert (
|
assert _ALL_WORKER_NAMES is not None, (
|
||||||
_ALL_WORKER_NAMES is not None
|
"`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
|
||||||
), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
|
)
|
||||||
worker_names = _ALL_WORKER_NAMES
|
worker_names = _ALL_WORKER_NAMES
|
||||||
leader_name = min(worker_names)
|
leader_name = min(worker_names)
|
||||||
|
|
||||||
@ -930,8 +930,7 @@ def _get_should_profile():
|
|||||||
ActiveProfilerType = torch._C._profiler.ActiveProfilerType
|
ActiveProfilerType = torch._C._profiler.ActiveProfilerType
|
||||||
return (
|
return (
|
||||||
torch.autograd._profiler_enabled()
|
torch.autograd._profiler_enabled()
|
||||||
and torch._C._autograd._profiler_type()
|
and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
|
||||||
== ActiveProfilerType.LEGACY # type: ignore[attr-defined]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ def _to_device(device: DeviceType) -> torch.device:
|
|||||||
|
|
||||||
|
|
||||||
def _to_device_map(
|
def _to_device_map(
|
||||||
device_map: dict[DeviceType, DeviceType]
|
device_map: dict[DeviceType, DeviceType],
|
||||||
) -> dict[torch.device, torch.device]:
|
) -> dict[torch.device, torch.device]:
|
||||||
full_device_map: dict[torch.device, torch.device] = {}
|
full_device_map: dict[torch.device, torch.device] = {}
|
||||||
reverse_map: dict[torch.device, torch.device] = {}
|
reverse_map: dict[torch.device, torch.device] = {}
|
||||||
@ -127,7 +127,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
|||||||
>>> options = TensorPipeRpcBackendOptions(
|
>>> options = TensorPipeRpcBackendOptions(
|
||||||
>>> num_worker_threads=8,
|
>>> num_worker_threads=8,
|
||||||
>>> device_maps={"worker1": {0: 1}}
|
>>> device_maps={"worker1": {0: 1}}
|
||||||
>>> # maps worker0's cuda:0 to worker1's cuda:1
|
>>> # maps worker0's cuda:0 to worker1's cuda:1
|
||||||
>>> )
|
>>> )
|
||||||
>>> options.set_device_map("worker1", {1: 2})
|
>>> options.set_device_map("worker1", {1: 2})
|
||||||
>>> # maps worker0's cuda:1 to worker1's cuda:2
|
>>> # maps worker0's cuda:1 to worker1's cuda:2
|
||||||
|
@ -63,10 +63,14 @@ class _server_process_global_profile(profile):
|
|||||||
>>> import torch.distributed.rpc as rpc
|
>>> import torch.distributed.rpc as rpc
|
||||||
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
|
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
|
||||||
>>> x, y = torch.tensor(1), torch.tensor(2)
|
>>> x, y = torch.tensor(1), torch.tensor(2)
|
||||||
>>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
|
>>> outer_profile_rref = rpc.remote(
|
||||||
|
... dst_worker_name, rpc._server_process_global_profile
|
||||||
|
... )
|
||||||
>>> outer_profile_rref.rpc_sync().__enter__()
|
>>> outer_profile_rref.rpc_sync().__enter__()
|
||||||
>>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
|
>>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
|
||||||
>>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
|
>>> inner_profile_rref = rpc.remote(
|
||||||
|
... dst_worker_name, rpc._server_process_global_profile
|
||||||
|
... )
|
||||||
>>> inner_profile_rref.rpc_sync().__enter__()
|
>>> inner_profile_rref.rpc_sync().__enter__()
|
||||||
>>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
|
>>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
|
||||||
>>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
|
>>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
|
||||||
|
@ -289,9 +289,9 @@ Important Notices
|
|||||||
|
|
||||||
::
|
::
|
||||||
|
|
||||||
>>> # xdoctest: +SKIP("stub")
|
>>> # xdoctest: +SKIP("stub")
|
||||||
>>> import torch.distributed as dist
|
>>> import torch.distributed as dist
|
||||||
>>> dist.init_process_group(backend="gloo|nccl")
|
>>> dist.init_process_group(backend="gloo|nccl")
|
||||||
|
|
||||||
3. In your training program, you can either use regular distributed functions
|
3. In your training program, you can either use regular distributed functions
|
||||||
or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
|
or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
|
||||||
@ -302,9 +302,9 @@ Important Notices
|
|||||||
::
|
::
|
||||||
|
|
||||||
local_rank = int(os.environ["LOCAL_RANK"])
|
local_rank = int(os.environ["LOCAL_RANK"])
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model,
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
device_ids=[local_rank],
|
model, device_ids=[local_rank], output_device=local_rank
|
||||||
output_device=local_rank)
|
)
|
||||||
|
|
||||||
Please ensure that ``device_ids`` argument is set to be the only GPU device id
|
Please ensure that ``device_ids`` argument is set to be the only GPU device id
|
||||||
that your code will be operating on. This is generally the local rank of the
|
that your code will be operating on. This is generally the local rank of the
|
||||||
@ -331,17 +331,18 @@ utility
|
|||||||
|
|
||||||
::
|
::
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
load_checkpoint(checkpoint_path)
|
load_checkpoint(checkpoint_path)
|
||||||
initialize()
|
initialize()
|
||||||
train()
|
train()
|
||||||
|
|
||||||
def train():
|
|
||||||
for batch in iter(dataset):
|
|
||||||
train_step(batch)
|
|
||||||
|
|
||||||
if should_checkpoint:
|
def train():
|
||||||
save_checkpoint(checkpoint_path)
|
for batch in iter(dataset):
|
||||||
|
train_step(batch)
|
||||||
|
|
||||||
|
if should_checkpoint:
|
||||||
|
save_checkpoint(checkpoint_path)
|
||||||
|
|
||||||
9. (Recommended) On worker errors, this tool will summarize the details of the error
|
9. (Recommended) On worker errors, this tool will summarize the details of the error
|
||||||
(e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
|
(e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
|
||||||
@ -353,17 +354,19 @@ utility
|
|||||||
|
|
||||||
::
|
::
|
||||||
|
|
||||||
from torch.distributed.elastic.multiprocessing.errors import record
|
from torch.distributed.elastic.multiprocessing.errors import record
|
||||||
|
|
||||||
@record
|
|
||||||
def main():
|
|
||||||
# do train
|
|
||||||
pass
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
@record
|
||||||
main()
|
def main():
|
||||||
|
# do train
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
|
@@ -297,9 +297,9 @@ class DTensor(torch.Tensor):

    @staticmethod
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        assert (
-            flatten_spec is not None
-        ), "Expecting spec to be not None from `__tensor_flatten__` return value!"
+        assert flatten_spec is not None, (
+            "Expecting spec to be not None from `__tensor_flatten__` return value!"
+        )
        local_tensor = inner_tensors["_local_tensor"]
        spec, requires_grad = flatten_spec
        unflatten_tensor_meta = TensorMeta(
@@ -694,9 +694,7 @@ def distribute_tensor(
                xla_distribute_tensor,
            )

-            return xla_distribute_tensor(
-                tensor, device_mesh, placements
-            )  # type:ignore[return-value]
+            return xla_distribute_tensor(tensor, device_mesh, placements)  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e
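
Most hunks in this diff follow the same pattern as the `__tensor_unflatten__` change above: Black wrapped the *condition* of an over-long `assert` in parentheses, while ruff format keeps the condition on the `assert` line and parenthesizes the *message* instead. A minimal, self-contained sketch of the resulting style (the function and names below are illustrative only, not taken from this PR)::

    def check_positive(value: int) -> int:
        # ruff format keeps the condition on the assert line and wraps only the
        # (long) message in parentheses.
        assert value > 0, (
            f"expected a strictly positive value, got {value}; this message is long "
            "enough that the formatter has to wrap it"
        )
        return value


    print(check_positive(3))  # prints 3

The behavior of the asserts is unchanged; only the way the statement is broken across lines differs between the two formatters.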
@@ -930,7 +928,9 @@ def distribute_module(
                FutureWarning,
                stacklevel=2,
            )
-            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
+            module.register_forward_pre_hook(
+                lambda _, inputs: input_fn(inputs, device_mesh)  # type: ignore[call-arg]
+            )
        elif num_args == 3:
            # input_fn takes in module, inputs, device mesh
            module.register_forward_pre_hook(
@@ -990,9 +990,9 @@ def _dtensor_init_helper(  # type: ignore[no-untyped-def]
    placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))

    # check device_mesh againts placements
-    assert device_mesh.ndim == len(
-        placements
-    ), "mesh dimension does not match the length of placements"
+    assert device_mesh.ndim == len(placements), (
+        "mesh dimension does not match the length of placements"
+    )

    assert kwargs["layout"] == torch.strided, "layout value not supported!"
    torch_stride = torch._prims_common.make_contiguous_strides_for(size)
@@ -75,7 +75,8 @@ def found_inf_reduce_handler(
) -> None:
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    local_tensor_args = pytree.tree_unflatten(
-        cast(list[object], op_info.local_args), op_info.args_tree_spec  # type: ignore[arg-type]
+        cast(list[object], op_info.local_args),
+        op_info.args_tree_spec,  # type: ignore[arg-type]
    )
    local_tensor_args = cast(tuple[object, ...], local_tensor_args)
    op_call(*local_tensor_args, **op_info.local_kwargs)
@@ -200,8 +201,9 @@ class OpDispatcher:
                # did not already construct one
                random._rng_tracker = random.OffsetBasedRNGTracker(mesh)

-            first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
-                torch.Tensor, local_tensor_args[0]
-            )
+            first_arg, first_local_arg = (
+                cast(dtensor.DTensor, args[0]),
+                cast(torch.Tensor, local_tensor_args[0]),
+            )
            rng_context = (
                random._rng_tracker._distribute_region(first_arg._spec)
@@ -422,18 +424,18 @@ class OpDispatcher:
    def wrap(res: object, spec: OutputSpecType) -> object:
        if isinstance(res, torch.Tensor):
            if spec is not None:
-                assert isinstance(
-                    spec, DTensorSpec
-                ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
+                assert isinstance(spec, DTensorSpec), (
+                    f"output spec does not match with output! Expected DTensorSpec, got {spec}."
+                )
                return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
            else:
                # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
                assert res.ndim == 0, "output tensor should be scalar!"
                return res
        elif isinstance(res, (list, tuple)):
-            assert spec is not None and isinstance(
-                spec, (list, tuple)
-            ), f"output spec does not match with output! Expected list/tuple, got {spec}."
+            assert spec is not None and isinstance(spec, (list, tuple)), (
+                f"output spec does not match with output! Expected list/tuple, got {spec}."
+            )
            res_list = []
            for e, s in zip(res, spec):
                res_list.append(OpDispatcher.wrap(e, s))
@@ -152,9 +152,9 @@ class OpStrategy(StrategyType):
        if isinstance(output_spec, DTensorSpec):
            return output_spec.mesh.shape
        else:
-            assert isinstance(
-                output_spec, tuple
-            ), "found no DTensorSpec in the OpStrategy!"
+            assert isinstance(output_spec, tuple), (
+                "found no DTensorSpec in the OpStrategy!"
+            )
            assert output_spec[0] is not None
            return output_spec[0].mesh.shape

@@ -63,9 +63,9 @@ class EinsumDims:
            if is_batch_dim:
                batch_dims.append(dim_char)
            else:
-                assert (
-                    len(input_dims) == 2
-                ), "free dimension only supported for two inputs!"
+                assert len(input_dims) == 2, (
+                    "free dimension only supported for two inputs!"
+                )
                lhs, rhs = input_dims
                if dim_char in lhs:
                    lhs_out_only_dims.append(dim_char)
@@ -89,9 +89,9 @@ class _MaskPartial(Partial):
        # override parent logic to perform partial mask for embedding
        num_chunks = mesh.size(mesh_dim)
        # get local shard size and offset on the embedding_dim
-        assert (
-            self.offset_shape is not None
-        ), "offset_shape needs to be set for _MaskPartial"
+        assert self.offset_shape is not None, (
+            "offset_shape needs to be set for _MaskPartial"
+        )
        local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
            self.offset_shape[self.offset_dim],
            num_chunks,
@@ -994,9 +994,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
        )
        output_specs_list.append(weight_out_spec if output_mask[1] else None)
    else:
-        assert (
-            output_mask[1] is False
-        ), "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
+        assert output_mask[1] is False, (
+            "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
+        )
        output_specs_list.append(None)

    # arg: bias
@@ -1020,9 +1020,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
        )
        output_specs_list.append(bias_out_spec if output_mask[2] else None)
    else:
-        assert (
-            output_mask[2] is False
-        ), "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
+        assert output_mask[2] is False, (
+            "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
+        )
        output_specs_list.append(None)

    out_tuple_strategy.strategies.append(
@@ -155,9 +155,9 @@ def _scaled_mm_like_strategy(
    assert isinstance(scale_mat2_strategy, OpStrategy)
    # TODO: add support for these later
    assert bias_strategy is None, "_scaled_mm on DTensors doesn't support bias"
-    assert (
-        scale_result_strategy is None
-    ), "_scaled_mm on DTensors doesn't support scale_result"
+    assert scale_result_strategy is None, (
+        "_scaled_mm on DTensors doesn't support scale_result"
+    )
    # generate all possible strategies for mm
    mm_strategy = gen_einsum_strategies(mm_equation, mesh)
    # filter out invalid strategies and associate costs
@@ -445,9 +445,9 @@ def pointwise_strategy(

    followed_strategy = op_schema.args_schema[max_shards_strategy_index]

-    assert isinstance(
-        followed_strategy, OpStrategy
-    ), f"no strategy to follow for {op_schema}!"
+    assert isinstance(followed_strategy, OpStrategy), (
+        f"no strategy to follow for {op_schema}!"
+    )
    return common_pointwise_strategy(
        mesh, op_schema.args_schema, followed_strategy, linearity
    )
@@ -254,9 +254,9 @@ def dim_movedim(

def dim_repeat(ndim: int, sizes: Shape) -> DimMap:
    sizes = normalize_sizes(sizes)
-    assert (
-        len(sizes) >= ndim
-    ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
+    assert len(sizes) >= ndim, (
+        f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
+    )
    pad = len(sizes) - ndim
    return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple(
        Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:])
@@ -275,9 +275,9 @@ def infer_size(total_size: int, sizes: Shape) -> Shape:
    if infers:
        size = -size
        missing_size = total_size // size
-        assert (
-            total_size % size == 0
-        ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
+        assert total_size % size == 0, (
+            f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
+        )
        return tuple(s if s != -1 else missing_size for s in sizes)
    assert size == total_size, f"sizes do not match {total_size} vs {size}"
    return sizes
@@ -538,9 +538,9 @@ def propagate_shape_and_sharding(
        for size, shard in zip(mesh_sizes, input_src_placements):
            if isinstance(shard, Shard) and shard.dim == in_dim:
                submesh_size *= size
-        assert (
-            out_size % submesh_size == 0
-        ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
+        assert out_size % submesh_size == 0, (
+            f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
+        )

        # we will only shard our first component of the split
        return in_dim if cmd.split_id == 0 else None
@@ -45,7 +45,7 @@ def register_prop_rule(
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def wrapper(
-        impl: Callable[[OpSchema], OutputSharding]
+        impl: Callable[[OpSchema], OutputSharding],
    ) -> Callable[[OpSchema], OutputSharding]:
        overloads = op if isinstance(op, list) else [op]
        for overload in overloads:
@@ -102,7 +102,7 @@ def register_op_strategy(


def as_list(
-    x: Union[list[object], object]
+    x: Union[list[object], object],
    # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]:  # type: ignore[valid-type]
    # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
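
The two hunks just above show another recurring ruff-format change: when a signature has to stay split across lines, the last (here, the only) parameter receives a trailing comma. A small sketch of the resulting style (hypothetical names, not from this PR)::

    from typing import Callable


    def apply_rule(
        rule: Callable[[int], int],  # trailing comma keeps the parameter on its own line
    ) -> int:
        # tiny demo body so the sketch actually runs
        return rule(21)


    print(apply_rule(lambda n: n * 2))  # 42

The trailing ("magic") comma is what tells the formatter to keep the parameter list expanded rather than collapsing it back onto one line.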
@@ -231,9 +231,9 @@ def redistribute_local_tensor(
                    local_tensor, device_mesh, i, my_coordinate[i]
                )
            else:
-                assert (
-                    current.is_shard()
-                ), f"Current placement should be shard but found {current}"
+                assert current.is_shard(), (
+                    f"Current placement should be shard but found {current}"
+                )
                shard_spec = cast(Shard, current)
                if shard_spec.dim != target_placement.dim:
                    new_local_tensor = shard_spec._to_new_shard_dim(
@@ -487,9 +487,9 @@ class ShardingPropagator:

            strategy_costs: list[float] = []
            for strtg in strategy.strategies:
-                assert (
-                    strtg.redistribute_cost is not None
-                ), "must set redistribute cost each strategy!"
+                assert strtg.redistribute_cost is not None, (
+                    "must set redistribute cost each strategy!"
+                )
                redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
                strategy_costs.append(redistribute_cost)

@@ -73,9 +73,9 @@ def compute_local_shape_and_global_offset(
        if isinstance(placement, Shard):
            shard_dim = placement.dim
            local_offset = [0] * len(global_shape)
-            assert shard_dim < len(
-                local_shape
-            ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
+            assert shard_dim < len(local_shape), (
+                f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
+            )
            shard_size, shard_offset = placement._local_shard_size_on_dim(
                local_shape[shard_dim],
                mesh_dim_size,
@@ -141,16 +141,15 @@ def compute_local_shape_and_global_offset(

            if isinstance(placement, _StridedShard):
                strided_part_seen[shard_dim] = True
-                shard_idx_stride_by_mesh_dim[shard_dim][
-                    idx
-                ] = num_shards_by_tensor_dim[shard_dim] // (
-                    placement.split_factor * mesh_dim_size
-                )
+                shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
+                    num_shards_by_tensor_dim[shard_dim]
+                    // (placement.split_factor * mesh_dim_size)
+                )
            else:
                num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
-                shard_idx_stride_by_mesh_dim[shard_dim][
-                    idx
-                ] = num_shards_by_tensor_dim[shard_dim]
+                shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
+                    num_shards_by_tensor_dim[shard_dim]
+                )

    shard_idx = [
        sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
@@ -205,9 +204,9 @@ def compute_global_tensor_info(
            )
            shard_dim = shard_placement.dim

-            assert (
-                shard_dim < tensor.ndim
-            ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
+            assert shard_dim < tensor.ndim, (
+                f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
+            )

            local_dim_size = tensor_shape[shard_dim]
            tensor_shape[shard_dim] = local_dim_size * mesh_dim_size
@@ -283,9 +283,9 @@ class CommDebugMode(TorchDispatchMode):
                "module_type" in self.advanced_module_tracker.module_helper_dict[fqn]
                and include_module_data
            ):
-                json_dict[
-                    "module_type"
-                ] = self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]
+                json_dict["module_type"] = (
+                    self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]
+                )

            if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
                for (
@@ -659,9 +659,9 @@ class CommDebugMode(TorchDispatchMode):
        operation_dict["is_bw"] = self.advanced_module_tracker.is_bw

        # tracks if the operation is part of activation checkpointing
-        operation_dict[
-            "is_activation_checkpointing"
-        ] = self.advanced_module_tracker.activation_checkpointing
+        operation_dict["is_activation_checkpointing"] = (
+            self.advanced_module_tracker.activation_checkpointing
+        )

        if any(t == DTensor for t in types):
            for ele in args:
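
The two `CommDebugMode` hunks above illustrate how ruff format handles an over-long subscript assignment: instead of breaking inside the `[...]` key, it keeps the target on one line and wraps the right-hand side in parentheses. A minimal sketch with made-up names (not from this PR)::

    class _RunInfo:
        description_of_the_currently_executing_backward_region = "collective-heavy"


    summary: dict[str, str] = {}
    info = _RunInfo()
    # Black split the subscript target across lines; ruff format keeps the target
    # intact and parenthesizes the (long) value instead.
    summary["description_of_the_currently_executing_backward_region"] = (
        info.description_of_the_currently_executing_backward_region
    )
    print(summary)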
@@ -108,9 +108,9 @@ def _compute_local_shape_and_global_offset(
        if isinstance(placement, Shard):
            shard_dim = placement.dim
            local_offset = [0] * len(global_shape)
-            assert shard_dim < len(
-                local_shape
-            ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
+            assert shard_dim < len(local_shape), (
+                f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
+            )
            shard_size, shard_offset = placement._local_shard_size_on_dim(
                local_shape[shard_dim],
                mesh_dim_size,
@@ -2,6 +2,7 @@
To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.py -e MLP_operation_tracing
"""

import argparse
import os
from typing import Callable, Union
@@ -6,6 +6,7 @@ with intermediate activations sharded across mutliple GPUs via DTensor
To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
"""

import os
import time

@@ -3,6 +3,7 @@
The following example demonstrates how to represent torchrec's embedding
sharding with the DTensor API.
"""

import argparse
import os
from functools import cached_property
@@ -253,22 +253,18 @@ class _AttentionOp(Protocol):
        key: torch.Tensor,
        value: torch.Tensor,
        **kwargs: object,
-    ) -> tuple[torch.Tensor, ...]:
-        ...
+    ) -> tuple[torch.Tensor, ...]: ...


class _RingRotater(ABC):
    @abstractmethod
-    def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None:
-        ...
+    def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ...

    @abstractmethod
-    def exchange_buffers(self, curr_buffer: torch.Tensor) -> None:
-        ...
+    def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ...

    @abstractmethod
-    def next_buffer(self) -> torch.Tensor:
-        ...
+    def next_buffer(self) -> torch.Tensor: ...


class _AllToAllRotater(_RingRotater):
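
In `Protocol` and `ABC` definitions, ruff format collapses an empty `...` body onto the `def` line, as in the `_AttentionOp` and `_RingRotater` hunk above. A short sketch of the resulting style (the classes below are illustrative, not part of this PR)::

    from abc import ABC, abstractmethod


    class Rotater(ABC):
        @abstractmethod
        def exchange(self, value: int) -> int: ...  # stub body stays on one line


    class AddOne(Rotater):
        def exchange(self, value: int) -> int:
            return value + 1


    print(AddOne().exchange(41))  # 42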
@@ -1097,15 +1093,13 @@ class _LoadBalancer(ABC):
    @abstractmethod
    def shard(
        cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...

    @classmethod
    @abstractmethod
    def unshard(
        cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
-    ) -> torch.Tensor:
-        ...
+    ) -> torch.Tensor: ...


class _SequentialSharder(_LoadBalancer):
@@ -1147,9 +1141,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
    def shard(
        cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
    ) -> torch.Tensor:
-        assert (
-            cls.ROUND_ROBIN_CYCLE == 2
-        ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
+        assert cls.ROUND_ROBIN_CYCLE == 2, (
+            "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
+        )
        cp_world_size = mesh.size()
        cp_rank = mesh.get_local_rank()
        assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0
@@ -1163,9 +1157,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
    def unshard(
        cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
    ) -> torch.Tensor:
-        assert (
-            cls.ROUND_ROBIN_CYCLE == 2
-        ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
+        assert cls.ROUND_ROBIN_CYCLE == 2, (
+            "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
+        )
        buffer = buffer.contiguous()
        cp_world_size = mesh.size()

@@ -113,9 +113,15 @@ def local_map(
    >>> device_mesh=device_mesh,
    >>> )
    >>>
-    >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor
-    >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor
-    >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
+    >>> W_dt = distribute_tensor(
+    ...     W, device_mesh, (col_wise)
+    ... )  # col-wisely sharded W tensor
+    >>> X_dt = distribute_tensor(
+    ...     X, device_mesh, (row_wise)
+    ... )  # row-wisely sharded X tensor
+    >>> Y_dt = local_mm_allreduce_forward(
+    ...     device_mesh, W_dt, X_dt
+    ... )  # apply local_mm_allreduce_forward to DTensors

    .. note:: This API is currently experimental and subject to change
    """
@@ -151,9 +157,9 @@ def local_map(
            )
            if in_placements is not None:
                spec = in_placements[idx]
-                assert (
-                    spec is not None
-                ), f"DTensor input {arg} expects placements but received {spec}!"
+                assert spec is not None, (
+                    f"DTensor input {arg} expects placements but received {spec}!"
+                )

                if not isinstance(spec, tuple):
                    spec = tuple(spec)
@@ -208,17 +214,17 @@ def local_map(
            )
            for out, spec in zip(flat_out, out_placements_tuple):
                if isinstance(out, torch.Tensor):
-                    assert not isinstance(
-                        out, DTensor
-                    ), f"torch.Tensor output expected but received {type(out)}: {out}"
+                    assert not isinstance(out, DTensor), (
+                        f"torch.Tensor output expected but received {type(out)}: {out}"
+                    )

                    flat_dist_out.append(
                        DTensor.from_local(out, device_mesh, spec, run_check=False)
                    )
                else:
-                    assert (
-                        spec is None
-                    ), f"Non-tensor output {out} expects None placements but received {spec}!"
+                    assert spec is None, (
+                        f"Non-tensor output {out} expects None placements but received {spec}!"
+                    )

                    flat_dist_out.append(out)

@@ -188,9 +188,14 @@ def _mark_sharding(
    """
    Mark the sharding strategy for each node in the graph module.
    """
-    placement_strategies: dict[
-        Node, PlacementStrategy
-    ] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)
+    placement_strategies: dict[Node, PlacementStrategy] = (
+        _mark_tensor_parallel_shardings(
+            gm,
+            graph_signature,
+            mesh,
+            parameter_placements,
+        )
+    )

    for node in gm.graph.nodes:
        if node.op == "placeholder":
@@ -202,9 +207,9 @@ def _mark_sharding(
        elif node.op == "call_function":
            if node.target == operator.getitem:
                input_nodes = node.all_input_nodes
-                assert (
-                    len(input_nodes) == 1
-                ), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
+                assert len(input_nodes) == 1, (
+                    f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
+                )
                arg_strategy = placement_strategies[input_nodes[0]]
                placement_strategies[node] = _create_placement_strategy(
                    node,
@@ -328,7 +328,9 @@ class DTensorExtensions(FSDPExtensions):
        self.device_handle = device_handle
        # we have to use the dynamo disable this way to disable dynamo as the decorater way would
        # trigger build failure with torch deploy...
-        self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform)  # type: ignore[method-assign]
+        self.post_unflatten_transform = torch._dynamo.disable(  # type: ignore[method-assign]
+            self.post_unflatten_transform
+        )

    def pre_flatten_transform(
        self,
@@ -64,9 +64,7 @@ def input_reshard(
    return module


-def _pack_hook_tp(
-    mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor
-) -> Any:  # noqa: D401
+def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any:  # noqa: D401
    """Hook function called after FWD to shard input."""
    if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
        return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
@@ -84,9 +82,7 @@ def _pack_hook_tp(
    return x


-def _unpack_hook_tp(
-    mesh: DeviceMesh, input_reshard_dim: int, x: Any
-) -> torch.Tensor:  # noqa: D401
+def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor:  # noqa: D401
    """Hook function called before activation recomputing in BWD to restore input."""
    if (
        isinstance(x, DTensor)
@@ -38,8 +38,7 @@ class ParallelStyle(ABC):
    src_data_rank: Optional[int] = 0

    @abstractmethod
-    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        ...
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ...


class ColwiseParallel(ParallelStyle):
@@ -467,19 +466,21 @@ class PrepareModuleInput(ParallelStyle):
        )
        self.use_local_output = use_local_output
        if self.input_layouts is not None:
-            assert (
-                self.desired_input_layouts is not None
-            ), "desired module inputs should not be None!"
-            assert len(self.input_layouts) == len(
-                self.desired_input_layouts
-            ), "input_layouts and desired_input_layouts should have same length!"
+            assert self.desired_input_layouts is not None, (
+                "desired module inputs should not be None!"
+            )
+            assert len(self.input_layouts) == len(self.desired_input_layouts), (
+                "input_layouts and desired_input_layouts should have same length!"
+            )
        self.with_kwargs = input_kwarg_layouts is not None
        self.input_kwarg_layouts = input_kwarg_layouts or {}
        self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
        if self.with_kwargs:
            assert len(self.input_kwarg_layouts) == len(
                self.desired_input_kwarg_layouts
-            ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
+            ), (
+                "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
+            )

    def _prepare_input_arg(
        self,
@@ -494,9 +495,9 @@ class PrepareModuleInput(ParallelStyle):
                # assert inp.placements[0] == input_layout
                dt_inp = input
            else:
-                assert isinstance(
-                    input, torch.Tensor
-                ), "expecting input to be a torch.Tensor!"
+                assert isinstance(input, torch.Tensor), (
+                    "expecting input to be a torch.Tensor!"
+                )
                dt_inp = DTensor.from_local(
                    input, mesh, (input_layout,), run_check=False
                )
@@ -517,9 +518,9 @@ class PrepareModuleInput(ParallelStyle):
        if len(inputs) != len(self.input_layouts):
            raise ValueError("module inputs and input_layouts should have same length!")

-        assert (
-            self.desired_input_layouts is not None
-        ), "desired module inputs should not be None!"
+        assert self.desired_input_layouts is not None, (
+            "desired module inputs should not be None!"
+        )
        for inp, input_layout, desired_layout in zip(
            inputs, self.input_layouts, self.desired_input_layouts
        ):
@@ -551,7 +552,9 @@ class PrepareModuleInput(ParallelStyle):
                with_kwargs=True,
            )  # type: ignore[misc]
        else:
-            module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh))  # type: ignore[misc, call-arg]
+            module.register_forward_pre_hook(
+                lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
+            )  # type: ignore[misc, call-arg]
        return module


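
The hook-registration hunks above show the same line-length treatment applied to calls: when a single lambda argument pushes a call past the line limit, ruff format moves that argument onto its own indented line. A standalone sketch using plain `torch.nn` hooks (the module and hook below are made up for illustration, not taken from this PR)::

    import torch
    from torch import nn

    model = nn.Linear(4, 2)

    # The lone lambda argument gets its own line once the call would exceed the
    # configured line length; the closing parenthesis returns to the call's indent.
    model.register_forward_pre_hook(
        lambda module, inputs: print(f"about to run {module} on {len(inputs)} input(s)")
    )

    model(torch.randn(1, 4))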
@@ -611,9 +614,9 @@ class PrepareModuleOutput(ParallelStyle):
            else desired_output_layouts
        )
        self.use_local_output = use_local_output
-        assert len(self.output_layouts) == len(
-            self.desired_output_layouts
-        ), "output_layouts and desired_output_layouts should have same length!"
+        assert len(self.output_layouts) == len(self.desired_output_layouts), (
+            "output_layouts and desired_output_layouts should have same length!"
+        )

    def _prepare_out_fn(self, outputs, device_mesh):
        prepared_outputs = []
@@ -649,5 +652,7 @@ class PrepareModuleOutput(ParallelStyle):
        return tuple(prepared_outputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh))  # type: ignore[misc, call-arg]
+        module.register_forward_hook(
+            lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
+        )  # type: ignore[misc, call-arg]
        return module
@@ -83,9 +83,9 @@ class Shard(Placement):
        few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
        This is because collectives usually require equal size tensor inputs
        """
-        assert (
-            self.dim <= tensor.ndim
-        ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+        assert self.dim <= tensor.ndim, (
+            f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+        )

        # chunk tensor over dimension `dim` into n slices
        tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
@@ -468,9 +468,9 @@ class _StridedShard(Shard):
        """
        TODO: currently _StridedShard does not support padding
        """
-        assert (
-            self.dim <= tensor.ndim
-        ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+        assert self.dim <= tensor.ndim, (
+            f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+        )

        total_split = num_chunks * self.split_factor
        assert tensor.size(self.dim) % total_split == 0, (
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user