[BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
Committed by: PyTorch MergeBot
Parent: 4e160d5fd9
Commit: 995df34b19
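The hunks below are mechanical restyling from switching the formatter for these paths from black to ruff format (run through PyTorch's PYFMT/lintrunner tooling); each hunk shows the old layout followed by the ruff-formatted replacement. The most frequent change is how a long assert message is wrapped, and stub bodies such as a bare `...` are collapsed onto the `def` line. A minimal before/after sketch of the assert pattern follows; the `check_divisible` helper and its arguments are illustrative only, not code from this PR:

import torch


def check_divisible(tensor: torch.Tensor, dim: int, group_size: int) -> None:
    # Old (black) style: the condition is parenthesized and the message
    # trails the closing parenthesis.
    #
    # assert (
    #     tensor.size(dim) % group_size == 0
    # ), f"dimension {dim} must be a multiple of group_size {group_size}"
    #
    # New (ruff format) style: the condition stays on the assert line and
    # the message is wrapped in parentheses instead.
    assert tensor.size(dim) % group_size == 0, (
        f"dimension {dim} must be a multiple of group_size {group_size}"
    )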
@@ -59,7 +59,6 @@ USE_BLACK_FILELIST = re.compile(
    # torch/[a-c]*/**
    "torch/[a-c]*/**",
    # torch/d*/**
    "torch/d*/**",
    # torch/[e-n]*/**
    "torch/[e-n]*/**",
    # torch/optim/**
@@ -36,11 +36,9 @@ _M = TypeVar("_M", nn.Module, list[nn.Module])


class _ContractFn(Protocol, Generic[_P, _T, _TState]):
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        ...
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...

    def state(self, module: nn.Module) -> _TState:
        ...
    def state(self, module: nn.Module) -> _TState: ...


def contract(
@@ -92,7 +90,7 @@ def contract(
# wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
@wraps(state_cls)  # type: ignore[arg-type]
def inner(
    func: Callable[Concatenate[_M, _P], _M]
    func: Callable[Concatenate[_M, _P], _M],
) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
    @wraps(func)
    def wrapper(
@@ -232,9 +230,7 @@ def contract(
return module.__dict__.setdefault(  # type: ignore[call-overload]
    STATE_KEY,
    {},  # TODO(@yhcharles): this is a temporary fix, need a better way
).get(
    func
)  # type: ignore[call-overload]
).get(func)  # type: ignore[call-overload]

wrapper.state = get_state  # type: ignore[attr-defined]
@@ -274,9 +274,9 @@ def reduce_scatter_tensor(
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)

assert (
    self.size(scatter_dim) % group_size == 0
), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
assert self.size(scatter_dim) % group_size == 0, (
    f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
)
if scatter_dim != 0:
    tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
    self = torch.cat(tensor_list)
@@ -313,9 +313,9 @@ def reduce_scatter_tensor_autograd(
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)

assert (
    self.size(scatter_dim) % group_size == 0
), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
assert self.size(scatter_dim) % group_size == 0, (
    f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
)
if scatter_dim != 0:
    tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
    self = torch.cat(tensor_list)
@@ -414,9 +414,9 @@ def reduce_scatter_tensor_coalesced(

assert len(scatter_dim) == len(inputs)
for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
    assert (
        tensor.size(dim) % group_size == 0
    ), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
    assert tensor.size(dim) % group_size == 0, (
        f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
    )
    if dim != 0:
        tensor_list = torch.chunk(tensor, group_size, dim=dim)
        inputs[idx] = torch.cat(tensor_list)
@@ -574,6 +574,7 @@ class AsyncCollectiveTensor(torch.Tensor):
    tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
    return _maybe_wrap_tensor(tensor)
    """

    elem: torch.Tensor
    completed: bool

@@ -726,9 +727,9 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int
    group_size = len(rankset)
    tag = tag or c10d._get_group_tag(group)
elif isinstance(group, DeviceMesh):
    assert (
        group.ndim == 1
    ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
    assert group.ndim == 1, (
        "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
    )
    # TODO: it should run collective in the whole mesh instead of dim 0
    tag, rankset, _ = group._dim_group_infos[0]
    group_size = len(rankset)
@@ -763,9 +764,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
elif isinstance(group, str):
    return group
elif isinstance(group, DeviceMesh):
    assert (
        group.ndim == 1
    ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
    assert group.ndim == 1, (
        "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
    )
    return group._dim_group_infos[0][2]
elif isinstance(group, tuple):
    if (
@@ -837,11 +838,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    return y


@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y


x = torch.ones(1280, 1280, device="cuda") + self.rank
# the context manager ensures that `wait_tensor(y)` will wait on the correct work object
with allow_inflight_collective_as_graph_input_ctx():
@@ -1057,9 +1060,9 @@ def all_gather_tensor_inplace(
    tag: str = "",
    gather_dim: int = 0,
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    assert not async_op, (
        "Can't remap async version of inplace op to functional collective"
    )

    group = group or dist.group.WORLD
    assert group is not None
@@ -1076,9 +1079,9 @@ def reduce_scatter_tensor_inplace(
    scatter_dim: int = 0,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    assert not async_op, (
        "Can't remap async version of inplace op to functional collective"
    )

    group = group or dist.group.WORLD
    assert group is not None
@@ -1105,9 +1108,9 @@ def all_reduce_inplace(
    async_op: bool = False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    assert not async_op, (
        "Can't remap async version of inplace op to functional collective"
    )

    group = group or dist.group.WORLD
    assert group is not None
@@ -1124,9 +1127,9 @@ def all_to_all_inplace(
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    assert not async_op, (
        "Can't remap async version of inplace op to functional collective"
    )

    group = group or dist.group.WORLD
    assert group is not None
@@ -1149,12 +1152,12 @@ def all_gather_inplace(
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    assert all(
        t.size(0) == tensor.size(0) for t in tensor_list
    ), "Remapping variable size all_gather is not yet supported"
    assert not async_op, (
        "Can't remap async version of inplace op to functional collective"
    )
    assert all(t.size(0) == tensor.size(0) for t in tensor_list), (
        "Remapping variable size all_gather is not yet supported"
    )

    group = group or dist.group.WORLD
    assert group is not None
@@ -592,7 +592,9 @@ class ShardedTensor(ShardedTensorBase):
assert (
    isinstance(device, torch.device)
    and device.index == torch.cuda.current_device()
), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""
), (
    """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""
)

current_device = torch.device(torch.cuda.current_device())
# returns a copy of ShardedTensor on CUDA current device
@@ -831,7 +833,9 @@ class ShardedTensor(ShardedTensorBase):
    "rank:1/cuda:1",
],
)
>>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
>>> st = ShardedTensor._init_from_local_tensor(
...     local_tensor, sharding_spec, [2, 4]
... )
>>> st
ShardedTensor(
    ShardedTensorMetadata(
@@ -219,9 +219,7 @@ def reshard_local_shard(
output_tensor_size = list(st_size)
output_tensor_size[current_sharding_dim] = sharded_dim_size
output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
output_tensor_list[
    placement.rank()
] = torch.empty(  # type: ignore[union-attr, index]
output_tensor_list[placement.rank()] = torch.empty(  # type: ignore[union-attr, index]
    output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
)
indices.append(placement.rank())  # type: ignore[union-attr, index, arg-type]
@@ -16,6 +16,6 @@ with warnings.catch_warnings():
        stacklevel=2,
    )

sys.modules[
    "torch.distributed._sharded_tensor"
] = torch.distributed._shard.sharded_tensor
sys.modules["torch.distributed._sharded_tensor"] = (
    torch.distributed._shard.sharded_tensor
)
@@ -67,7 +67,7 @@ def _all_gather_sharded_tensor(


class CompanionMismatch(Exception):
    ...
    pass


def _iterate_state_dict(
@@ -409,9 +409,9 @@ def _create_cpu_state_dict(

def unpin_memory(t):
    succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))
    assert (
        succ == 0
    ), f"Unpinning shared memory failed with error-code: {succ}"
    assert succ == 0, (
        f"Unpinning shared memory failed with error-code: {succ}"
    )

weakref.finalize(t, unpin_memory, t)
succ = int(
@@ -421,9 +421,9 @@ def _create_cpu_state_dict(
            1,  # lines up with 'cudaHostRegisterPortable'
        )
    )
    assert (
        succ == 0
    ), f"Pinning shared memory failed with error-code: {succ}"
    assert succ == 0, (
        f"Pinning shared memory failed with error-code: {succ}"
    )
    return t
elif pin_memory:
    return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
@ -1525,8 +1525,7 @@ if TYPE_CHECKING:
|
||||
@overload
|
||||
def empty(
|
||||
*size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
@overload
|
||||
@ -1535,8 +1534,7 @@ def empty(
|
||||
*,
|
||||
dtype: Optional[_dtype] = None,
|
||||
device: Optional[_device] = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
def empty( # type: ignore[misc]
|
||||
|
@ -6,6 +6,7 @@ we keep the old import path starts with `_tensor` for
|
||||
backward compatibility. We will remove this folder once
|
||||
we resolve all the BC issues.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from importlib import import_module
|
||||
|
||||
|
@ -153,7 +153,7 @@ class FSDPMemTracker(MemTracker):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
fmt.display_snapshot("peak")
|
||||
fmt.display_modulewise_snapshots(depth = 3, units = "MB")
|
||||
fmt.display_modulewise_snapshots(depth=3, units="MB")
|
||||
|
||||
"""
|
||||
|
||||
|
@ -379,7 +379,7 @@ class MemTracker(TorchDispatchMode):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
mt.display_snapshot("peak")
|
||||
mt.display_modulewise_snapshots(depth = 3, units = "MiB")
|
||||
mt.display_modulewise_snapshots(depth=3, units="MiB")
|
||||
|
||||
Known Limitations:
|
||||
- The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``.
|
||||
|
@ -42,6 +42,7 @@ class ModTracker:
|
||||
def my_linear(m1, m2, bias):
|
||||
print(f"Current modules: {tracker.parents}")
|
||||
return torch.mm(m1, m2.t()) + bias
|
||||
|
||||
torch.nn.functional.linear = my_linear
|
||||
|
||||
mod(torch.rand(2, 2))
|
||||
|
@ -255,9 +255,9 @@ class RuntimeEstimator(TorchDispatchMode):
|
||||
Tuple[Any, float]: A tuple containing the result of the function and
|
||||
the mean operation time in milliseconds.
|
||||
"""
|
||||
assert isinstance(
|
||||
cls.fake_mode, FakeTensorMode
|
||||
), "Initialize/Assign FakeTensorMode before using this function"
|
||||
assert isinstance(cls.fake_mode, FakeTensorMode), (
|
||||
"Initialize/Assign FakeTensorMode before using this function"
|
||||
)
|
||||
mean_op_time = 0.0
|
||||
if func._overloadpacket not in _VIEW_OPS:
|
||||
try:
|
||||
@ -289,9 +289,9 @@ class RuntimeEstimator(TorchDispatchMode):
|
||||
Tuple[Any, float]: A tuple containing the result of the function and
|
||||
the mean operation time in milliseconds.
|
||||
"""
|
||||
assert (
|
||||
torch.cuda.is_available()
|
||||
), "Roofline estimation needs to access CUDA capabilities to make estimations"
|
||||
assert torch.cuda.is_available(), (
|
||||
"Roofline estimation needs to access CUDA capabilities to make estimations"
|
||||
)
|
||||
|
||||
def get_num_bytes(t: torch.Tensor) -> int:
|
||||
"""
|
||||
@ -324,9 +324,9 @@ class RuntimeEstimator(TorchDispatchMode):
|
||||
float: The estimated compute time in nanoseconds.
|
||||
"""
|
||||
if func_packet in flop_registry:
|
||||
assert (
|
||||
len(out_dtypes) == 1
|
||||
), f"Only support single out dtype got {out_dtypes} for {func_packet}"
|
||||
assert len(out_dtypes) == 1, (
|
||||
f"Only support single out dtype got {out_dtypes} for {func_packet}"
|
||||
)
|
||||
dtype = out_dtypes.pop()
|
||||
# This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s
|
||||
peak_gpu_flops = get_device_tflops(dtype) * 1e15
|
||||
@ -487,9 +487,9 @@ class RuntimeEstimator(TorchDispatchMode):
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
fake_mode = active_fake_mode()
|
||||
assert isinstance(
|
||||
fake_mode, FakeTensorMode
|
||||
), "No FakeTensorMode found, designed to used under FakeTensorMode"
|
||||
assert isinstance(fake_mode, FakeTensorMode), (
|
||||
"No FakeTensorMode found, designed to used under FakeTensorMode"
|
||||
)
|
||||
RuntimeEstimator.fake_mode = fake_mode
|
||||
self.total_runtime = 0.0
|
||||
self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0))
|
||||
|
@ -245,7 +245,7 @@ class SACEstimator(TorchDispatchMode):
|
||||
with FakeTensorMode():
|
||||
module = ...
|
||||
inp = ...
|
||||
with sac_estimator('operator-level-cost-model'):
|
||||
with sac_estimator("operator-level-cost-model"):
|
||||
output = module(inp)
|
||||
sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True)
|
||||
"""
|
||||
@ -442,9 +442,9 @@ class SACEstimator(TorchDispatchMode):
|
||||
out_storages_cpu.update(_get_untyped_storages(o))
|
||||
|
||||
# Check if there's more than 1 CUDA device
|
||||
assert (
|
||||
len(cuda_devices) <= 1
|
||||
), f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
|
||||
assert len(cuda_devices) <= 1, (
|
||||
f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
|
||||
)
|
||||
|
||||
# 2. Get the memory consumed by output
|
||||
nbytes_cuda = sum(
|
||||
@ -484,9 +484,9 @@ class SACEstimator(TorchDispatchMode):
|
||||
if acm_stats := self._sac_mod_metadata.get(mod_fqn, None):
|
||||
acm_stats.sac_metadata.append(acm)
|
||||
else:
|
||||
assert (
|
||||
mod_fqn == "Global"
|
||||
), f"Module {mod_fqn} not found in AC Mod Stats"
|
||||
assert mod_fqn == "Global", (
|
||||
f"Module {mod_fqn} not found in AC Mod Stats"
|
||||
)
|
||||
self._sac_metadata.append(acm)
|
||||
|
||||
return out
|
||||
@ -979,9 +979,9 @@ class SACEstimator(TorchDispatchMode):
|
||||
|
||||
def __enter__(self) -> Self: # type: ignore[no-untyped-def]
|
||||
fake_mode = active_fake_mode()
|
||||
assert isinstance(
|
||||
fake_mode, FakeTensorMode
|
||||
), "SAC Estimator should be called in FakeTensorMode"
|
||||
assert isinstance(fake_mode, FakeTensorMode), (
|
||||
"SAC Estimator should be called in FakeTensorMode"
|
||||
)
|
||||
RuntimeEstimator.fake_mode = fake_mode
|
||||
self._mod_tracker.register_user_hooks(
|
||||
pre_fw_hook=self._pre_fw_hook,
|
||||
|
@ -38,9 +38,9 @@ def _perform_local_step(
|
||||
"""
|
||||
overlap_info = zero._overlap_info
|
||||
bucket_index = bucket.index()
|
||||
assert (
|
||||
len(zero.optim.param_groups) == 1
|
||||
), "Overlapping DDP with ZeRO only supports a single parameter group"
|
||||
assert len(zero.optim.param_groups) == 1, (
|
||||
"Overlapping DDP with ZeRO only supports a single parameter group"
|
||||
)
|
||||
|
||||
# Construct the `gradients` input for the local optimizer step, which
|
||||
# expects `None` in a list position to indicate that the corresponding
|
||||
@ -49,9 +49,9 @@ def _perform_local_step(
|
||||
gradients: list[Optional[torch.Tensor]] = [
|
||||
_NO_PARAM_UPDATE for _ in range(num_local_optim_params)
|
||||
]
|
||||
assert (
|
||||
bucket_index in overlap_info.offsets
|
||||
), f"Bucket index {bucket_index} was not assigned to rank {rank}"
|
||||
assert bucket_index in overlap_info.offsets, (
|
||||
f"Bucket index {bucket_index} was not assigned to rank {rank}"
|
||||
)
|
||||
gradients_offset = overlap_info.offsets[bucket_index]
|
||||
bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
|
||||
bucket_offset = bucket_assignment.offset
|
||||
@ -77,13 +77,13 @@ def _broadcast_bucket(
|
||||
:class:`ZeroRedundancyOptimizer` instance.
|
||||
"""
|
||||
overlap_info = zero._overlap_info
|
||||
assert (
|
||||
len(overlap_info.assigned_ranks_per_bucket) > bucket_index
|
||||
), "`assigned_ranks_per_bucket` is not fully constructed"
|
||||
assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
|
||||
"`assigned_ranks_per_bucket` is not fully constructed"
|
||||
)
|
||||
# Sort to ensure the same ordering across ranks
|
||||
assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
|
||||
assert len(assigned_ranks) > 0, (
|
||||
f"Bucket {bucket_index} should be " "assigned to at least one rank"
|
||||
f"Bucket {bucket_index} should be assigned to at least one rank"
|
||||
)
|
||||
for assigned_rank in assigned_ranks:
|
||||
bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
|
||||
@ -273,9 +273,9 @@ def hook_with_zero_step(
|
||||
rank = zero.global_rank
|
||||
|
||||
assert overlap_info.status == _OverlapStatus.INITIALIZED
|
||||
assert (
|
||||
len(overlap_info.assigned_ranks_per_bucket) > bucket_index
|
||||
), "`assigned_ranks_per_bucket` is not fully constructed"
|
||||
assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
|
||||
"`assigned_ranks_per_bucket` is not fully constructed"
|
||||
)
|
||||
assigned_to_bucket = (
|
||||
rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
|
||||
)
|
||||
@ -288,9 +288,9 @@ def hook_with_zero_step(
|
||||
# Check that buckets are indexed incrementally starting from 0 in the
|
||||
# order of their autograd hooks firing
|
||||
if len(overlap_info.bucket_indices_seen) > 0:
|
||||
assert (
|
||||
overlap_info.bucket_indices_seen[-1] == bucket_index - 1
|
||||
), "Bucket indices are not in incremental order"
|
||||
assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, (
|
||||
"Bucket indices are not in incremental order"
|
||||
)
|
||||
else:
|
||||
assert bucket_index == 0, "Bucket indices do not start from 0"
|
||||
overlap_info.bucket_indices_seen.append(bucket_index)
|
||||
|
@ -129,7 +129,7 @@ def bf16_compress_hook(
|
||||
|
||||
|
||||
def fp16_compress_wrapper(
|
||||
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
|
||||
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
|
||||
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
|
||||
"""
|
||||
Cast input tensor to ``torch.float16``, cast result of hook back to input dtype.
|
||||
@ -167,7 +167,7 @@ def fp16_compress_wrapper(
|
||||
|
||||
|
||||
def bf16_compress_wrapper(
|
||||
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
|
||||
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
|
||||
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
|
||||
"""
|
||||
Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
|
||||
|
@ -223,8 +223,7 @@ class Join:
|
||||
self._rank = dist.get_rank(self._process_group)
|
||||
self._device = device
|
||||
|
||||
def __enter__(self):
|
||||
...
|
||||
def __enter__(self): ...
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
|
@ -52,7 +52,10 @@ def average_parameters(
|
||||
|
||||
|
||||
def get_params_to_average(
|
||||
params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]]
|
||||
params: Union[
|
||||
Iterable[torch.nn.Parameter],
|
||||
Iterable[dict[str, torch.nn.Parameter]],
|
||||
],
|
||||
):
|
||||
"""
|
||||
Return a list of parameters that need to average.
|
||||
|
@ -550,9 +550,7 @@ def create_default_global_save_plan(
|
||||
new_item = dataclasses.replace(item, index=new_index)
|
||||
new_items.append(new_item)
|
||||
|
||||
assert (
|
||||
item.tensor_data.chunk is not None
|
||||
), f"""
|
||||
assert item.tensor_data.chunk is not None, f"""
|
||||
Cannot create MD for tensor without bounds.
|
||||
FQN: {item.index.fqn}
|
||||
"""
|
||||
|
@ -414,41 +414,33 @@ class FileSystemBase(ABC):
|
||||
@abstractmethod
|
||||
def create_stream(
|
||||
self, path: Union[str, os.PathLike], mode: str
|
||||
) -> Generator[io.IOBase, None, None]:
|
||||
...
|
||||
) -> Generator[io.IOBase, None, None]: ...
|
||||
|
||||
@abstractmethod
|
||||
def concat_path(
|
||||
self, path: Union[str, os.PathLike], suffix: str
|
||||
) -> Union[str, os.PathLike]:
|
||||
...
|
||||
) -> Union[str, os.PathLike]: ...
|
||||
|
||||
@abstractmethod
|
||||
def rename(
|
||||
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
|
||||
...
|
||||
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: ...
|
||||
|
||||
@abstractmethod
|
||||
def mkdir(self, path: Union[str, os.PathLike]) -> None:
|
||||
...
|
||||
def mkdir(self, path: Union[str, os.PathLike]) -> None: ...
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
||||
...
|
||||
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, path: Union[str, os.PathLike]) -> bool:
|
||||
...
|
||||
def exists(self, path: Union[str, os.PathLike]) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def rm_file(self, path: Union[str, os.PathLike]) -> None:
|
||||
...
|
||||
def rm_file(self, path: Union[str, os.PathLike]) -> None: ...
|
||||
|
||||
|
||||
class FileSystem(FileSystemBase):
|
||||
@ -512,7 +504,6 @@ class FileSystem(FileSystemBase):
|
||||
|
||||
|
||||
class _FileSystemWriter(StorageWriter):
|
||||
|
||||
"""
|
||||
Basic implementation of StorageWriter using file IO.
|
||||
|
||||
@ -800,9 +791,9 @@ class FileSystemReader(StorageReader):
|
||||
)
|
||||
target_tensor = planner.resolve_tensor(req).detach()
|
||||
|
||||
assert (
|
||||
target_tensor.size() == tensor.size()
|
||||
), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
|
||||
assert target_tensor.size() == tensor.size(), (
|
||||
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
|
||||
)
|
||||
target_tensor.copy_(tensor)
|
||||
planner.commit_tensor(req, target_tensor)
|
||||
|
||||
|
@ -135,12 +135,12 @@ def _get_state_dict_2d_layout(
|
||||
for key, value in state_dict.items():
|
||||
specs[key] = (None, value.size())
|
||||
if _is_nested_tensor(value):
|
||||
assert (
|
||||
len(value.local_shards()) == 1
|
||||
), "Cannot handle ST with multiple shards"
|
||||
assert isinstance(
|
||||
value, ShardedTensor
|
||||
), "Can only handle nested ShardedTensor"
|
||||
assert len(value.local_shards()) == 1, (
|
||||
"Cannot handle ST with multiple shards"
|
||||
)
|
||||
assert isinstance(value, ShardedTensor), (
|
||||
"Can only handle nested ShardedTensor"
|
||||
)
|
||||
shard = value.local_shards()[0]
|
||||
specs[key] = (
|
||||
shard.metadata.shard_offsets,
|
||||
|
@ -151,7 +151,7 @@ class SavePlanner(abc.ABC):
|
||||
>>> storage_meta: Optional[StorageMeta],
|
||||
>>> is_coordinator: bool,
|
||||
>>> ) -> None:
|
||||
>>> # prefix all keys with `foo_``
|
||||
>>> # prefix all keys with `foo_``
|
||||
>>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
|
||||
|
||||
Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted
|
||||
@ -175,8 +175,8 @@ class SavePlanner(abc.ABC):
|
||||
>>> from itertools import zip_longest
|
||||
>>> from dataclasses import replace
|
||||
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
|
||||
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
|
||||
>>> # This sample doesn't handle ShardedTensors
|
||||
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
|
||||
>>> # This sample doesn't handle ShardedTensors
|
||||
>>> def create_global_plan(self, all_plans):
|
||||
>>> iters = [iter(all_plans[0].items)] * len(all_plans)
|
||||
>>> items_per_rank = [
|
||||
@ -347,7 +347,7 @@ class LoadPlanner:
|
||||
>>> self.is_coordinator = is_coordinator
|
||||
>>>
|
||||
>>> def load_bytes(self, read_item, value):
|
||||
>>> # Remove the "foo_" prefix
|
||||
>>> # Remove the "foo_" prefix
|
||||
>>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)
|
||||
|
||||
|
||||
|
@ -140,10 +140,12 @@ class StateDictOptions:
|
||||
@dataclass
|
||||
class _StateDictInfo(StateDictOptions):
|
||||
fqn_param_mapping: dict[
|
||||
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
|
||||
Union[str, torch.Tensor],
|
||||
Union[FQNS_T, torch.Tensor],
|
||||
] = field(default_factory=dict)
|
||||
shared_params_mapping: dict[
|
||||
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
|
||||
Union[str, torch.Tensor],
|
||||
Union[FQNS_T, torch.Tensor],
|
||||
] = field(default_factory=dict)
|
||||
submodule_prefixes: set[str] = field(default_factory=set)
|
||||
handle_model: bool = True
|
||||
@ -1140,7 +1142,9 @@ def get_state_dict(
|
||||
|
||||
|
||||
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
|
||||
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
|
||||
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
|
||||
... fsdp_model, fsdp_optim
|
||||
... )
|
||||
|
||||
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
|
||||
>>> # the asserts will fail.
|
||||
|
@ -125,7 +125,9 @@ def load(
|
||||
>>> my_model = MyModule()
|
||||
>>> optimizer = Adagrad(my_model.parameters())
|
||||
>>> model_state_dict = my_model.state_dict()
|
||||
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
|
||||
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
|
||||
... "/checkpoint/1"
|
||||
... )
|
||||
|
||||
>>> torch.distributed.checkpoint.load_state_dict(
|
||||
>>> state_dict=model_state_dict,
|
||||
|
@ -127,7 +127,9 @@ def save(
|
||||
|
||||
>>> state_dict = {"model": my_model}
|
||||
|
||||
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
|
||||
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
|
||||
... "/checkpoint/1"
|
||||
... )
|
||||
>>> torch.distributed.checkpoint.save(
|
||||
>>> state_dict=state_dict,
|
||||
>>> storage_writer=fs_storage_writer,
|
||||
@ -206,7 +208,9 @@ def async_save(
|
||||
|
||||
>>> state_dict = {"model": my_model}
|
||||
|
||||
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
|
||||
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
|
||||
... "/checkpoint/1"
|
||||
... )
|
||||
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
|
||||
>>> state_dict=state_dict,
|
||||
>>> storage_writer=fs_storage_writer,
|
||||
@ -223,7 +227,9 @@ def async_save(
|
||||
pg = process_group or _get_default_group()
|
||||
assert (
|
||||
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
|
||||
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
|
||||
), (
|
||||
"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
|
||||
)
|
||||
|
||||
storage_writer = cast(
|
||||
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
|
||||
|
@ -32,7 +32,7 @@ R = TypeVar("R")
|
||||
|
||||
|
||||
def _get_failure_dict(
|
||||
results: list[Union[T, WRAPPED_EXCEPTION]]
|
||||
results: list[Union[T, WRAPPED_EXCEPTION]],
|
||||
) -> dict[int, WRAPPED_EXCEPTION]:
|
||||
return cast(
|
||||
dict[int, WRAPPED_EXCEPTION],
|
||||
|
@ -221,8 +221,12 @@ else:
|
||||
if cur_rank in mesh_nd:
|
||||
res_flattened_mesh = flattened_mesh
|
||||
self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined]
|
||||
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh # type: ignore[possibly-undefined]
|
||||
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root) # type: ignore[possibly-undefined]
|
||||
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
|
||||
res_flattened_mesh # type: ignore[possibly-undefined]
|
||||
)
|
||||
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(
|
||||
flatten_dims_in_root
|
||||
) # type: ignore[possibly-undefined]
|
||||
|
||||
return res_flattened_mesh
|
||||
|
||||
@ -242,9 +246,9 @@ else:
|
||||
root_mesh = self.get_root_mesh(device_mesh)
|
||||
child_mesh_dim_names = device_mesh.mesh_dim_names
|
||||
if root_mesh and child_mesh_dim_names:
|
||||
assert (
|
||||
len(child_mesh_dim_names) == 1
|
||||
), "The submesh can only be a 1D mesh."
|
||||
assert len(child_mesh_dim_names) == 1, (
|
||||
"The submesh can only be a 1D mesh."
|
||||
)
|
||||
child_mesh_dim_name = child_mesh_dim_names[0]
|
||||
return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name)
|
||||
return None
|
||||
@ -763,7 +767,9 @@ else:
|
||||
root_mesh, None
|
||||
)
|
||||
if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
|
||||
dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2] # type: ignore[index]
|
||||
dim_group_infos = root_to_flatten_mapping[
|
||||
mesh_dim # type: ignore[index]
|
||||
]._dim_group_infos[0][:2]
|
||||
return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos))
|
||||
else:
|
||||
mesh_dim = (
|
||||
@ -905,9 +911,9 @@ else:
|
||||
mesh_dim = 0
|
||||
|
||||
mesh_dim_group = not_none(self.get_group(mesh_dim))
|
||||
assert isinstance(
|
||||
mesh_dim_group, ProcessGroup
|
||||
), "We expect ProcessGroup before calling `get_rank`!"
|
||||
assert isinstance(mesh_dim_group, ProcessGroup), (
|
||||
"We expect ProcessGroup before calling `get_rank`!"
|
||||
)
|
||||
return not_none(get_rank(mesh_dim_group))
|
||||
|
||||
def get_coordinate(self) -> Optional[list[int]]:
|
||||
|
@ -334,12 +334,12 @@ class Backend(str): # noqa: SLOT000
|
||||
# Allow UCC plugin if Pytorch is not built with native support.
|
||||
# TODO: remove this exception once UCC plugin is fully deprecated.
|
||||
if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()):
|
||||
assert not hasattr(
|
||||
Backend, name.upper()
|
||||
), f"{name.upper()} c10d backend already exist"
|
||||
assert (
|
||||
name.upper() not in Backend._plugins
|
||||
), f"{name.upper()} c10d backend creator function already exist"
|
||||
assert not hasattr(Backend, name.upper()), (
|
||||
f"{name.upper()} c10d backend already exist"
|
||||
)
|
||||
assert name.upper() not in Backend._plugins, (
|
||||
f"{name.upper()} c10d backend creator function already exist"
|
||||
)
|
||||
|
||||
setattr(Backend, name.upper(), name.lower())
|
||||
Backend.backend_list.append(name.lower())
|
||||
@ -1650,9 +1650,9 @@ def init_process_group(
|
||||
if "torch._dynamo" in sys.modules:
|
||||
torch._dynamo.trace_rules.clear_lru_cache()
|
||||
|
||||
assert (store is None) or (
|
||||
init_method is None
|
||||
), "Cannot specify both init_method and store."
|
||||
assert (store is None) or (init_method is None), (
|
||||
"Cannot specify both init_method and store."
|
||||
)
|
||||
|
||||
if store is not None:
|
||||
assert world_size > 0, "world_size must be positive if using store"
|
||||
@ -1734,7 +1734,10 @@ def init_process_group(
|
||||
)
|
||||
_update_default_pg(default_pg)
|
||||
|
||||
_world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index]
|
||||
_world.pg_group_ranks[GroupMember.WORLD] = { # type: ignore[index]
|
||||
i: i
|
||||
for i in range(GroupMember.WORLD.size()) # type: ignore[attr-defined]
|
||||
}
|
||||
_backend = _world.pg_map[not_none(GroupMember.WORLD)][0]
|
||||
_default_pg_init_method = init_method
|
||||
|
||||
@ -1959,9 +1962,9 @@ def _new_process_group_helper(
|
||||
if not is_nccl_available():
|
||||
raise RuntimeError("Distributed package doesn't have NCCL built in")
|
||||
if backend_options is not None:
|
||||
assert isinstance(
|
||||
backend_options, ProcessGroupNCCL.Options
|
||||
), "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
|
||||
assert isinstance(backend_options, ProcessGroupNCCL.Options), (
|
||||
"Expected backend_options argument to be of type ProcessGroupNCCL.Options"
|
||||
)
|
||||
if backend_options._timeout != timeout:
|
||||
warnings.warn(
|
||||
"backend_options._timeout was specified, "
|
||||
@ -2001,9 +2004,9 @@ def _new_process_group_helper(
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.XCCL
|
||||
else:
|
||||
assert (
|
||||
backend_str.upper() in Backend._plugins
|
||||
), f"Unknown c10d backend type {backend_str.upper()}"
|
||||
assert backend_str.upper() in Backend._plugins, (
|
||||
f"Unknown c10d backend type {backend_str.upper()}"
|
||||
)
|
||||
|
||||
backend_plugin = Backend._plugins[backend_str.upper()]
|
||||
creator_fn = backend_plugin.creator_fn
|
||||
@ -2630,8 +2633,10 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
|
||||
>>> # xdoctest: +SKIP("no rank")
|
||||
>>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
|
||||
>>> recv_tensor = torch.randn(2, dtype=torch.float32)
|
||||
>>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size)
|
||||
>>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size)
|
||||
>>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
|
||||
>>> recv_op = dist.P2POp(
|
||||
... dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
|
||||
... )
|
||||
>>> reqs = batch_isend_irecv([send_op, recv_op])
|
||||
>>> for req in reqs:
|
||||
>>> req.wait()
|
||||
@ -2758,7 +2763,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
|
||||
>>> # xdoctest: +SKIP("no rank")
|
||||
>>> # All tensors below are of torch.int64 type.
|
||||
>>> # We have 2 process groups, 2 ranks.
|
||||
>>> device = torch.device(f'cuda:{rank}')
|
||||
>>> device = torch.device(f"cuda:{rank}")
|
||||
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
|
||||
>>> tensor
|
||||
tensor([1, 2], device='cuda:0') # Rank 0
|
||||
@ -2770,7 +2775,9 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
|
||||
|
||||
>>> # All tensors below are of torch.cfloat type.
|
||||
>>> # We have 2 process groups, 2 ranks.
|
||||
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
|
||||
>>> tensor = torch.tensor(
|
||||
... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
|
||||
... ) + 2 * rank * (1 + 1j)
|
||||
>>> tensor
|
||||
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
|
||||
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
|
||||
@ -3380,9 +3387,9 @@ def recv_object_list(
|
||||
)
|
||||
|
||||
rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
|
||||
assert (
|
||||
rank_sizes == rank_objects
|
||||
), "Mismatch in return ranks for object sizes and objects."
|
||||
assert rank_sizes == rank_objects, (
|
||||
"Mismatch in return ranks for object sizes and objects."
|
||||
)
|
||||
# Deserialize objects using their stored sizes.
|
||||
offset = 0
|
||||
for i, obj_size in enumerate(object_sizes_tensor):
|
||||
@ -3673,8 +3680,10 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
|
||||
>>> # xdoctest: +SKIP("need process group init")
|
||||
>>> # All tensors below are of torch.int64 dtype.
|
||||
>>> # We have 2 process groups, 2 ranks.
|
||||
>>> device = torch.device(f'cuda:{rank}')
|
||||
>>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)]
|
||||
>>> device = torch.device(f"cuda:{rank}")
|
||||
>>> tensor_list = [
|
||||
... torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
|
||||
... ]
|
||||
>>> tensor_list
|
||||
[tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
|
||||
[tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
|
||||
@ -3689,11 +3698,15 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
|
||||
|
||||
>>> # All tensors below are of torch.cfloat dtype.
|
||||
>>> # We have 2 process groups, 2 ranks.
|
||||
>>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)]
|
||||
>>> tensor_list = [
|
||||
... torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
|
||||
... ]
|
||||
>>> tensor_list
|
||||
[tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
|
||||
[tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
|
||||
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
|
||||
>>> tensor = torch.tensor(
|
||||
... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
|
||||
... ) + 2 * rank * (1 + 1j)
|
||||
>>> tensor
|
||||
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
|
||||
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
|
||||
@ -3769,7 +3782,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
|
||||
>>> # xdoctest: +SKIP("need process group init")
|
||||
>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
|
||||
>>> # We have two ranks.
|
||||
>>> device = torch.device(f'cuda:{rank}')
|
||||
>>> device = torch.device(f"cuda:{rank}")
|
||||
>>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
|
||||
>>> tensor_in
|
||||
tensor([1, 2], device='cuda:0') # Rank 0
|
||||
@ -3969,8 +3982,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
|
||||
)
|
||||
elif gather_list:
|
||||
raise ValueError(
|
||||
"Argument ``gather_list`` must NOT be specified "
|
||||
"on non-destination ranks."
|
||||
"Argument ``gather_list`` must NOT be specified on non-destination ranks."
|
||||
)
|
||||
|
||||
|
||||
@ -4141,8 +4153,7 @@ def scatter(
|
||||
else:
|
||||
if scatter_list:
|
||||
raise ValueError(
|
||||
"Argument ``scatter_list`` must NOT be specified "
|
||||
"on non-source ranks."
|
||||
"Argument ``scatter_list`` must NOT be specified on non-source ranks."
|
||||
)
|
||||
input_tensors = []
|
||||
output_tensors = [tensor]
|
||||
@ -4225,7 +4236,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
|
||||
>>> # xdoctest: +SKIP("need process group init")
|
||||
>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
|
||||
>>> # We have two ranks.
|
||||
>>> device = torch.device(f'cuda:{rank}')
|
||||
>>> device = torch.device(f"cuda:{rank}")
|
||||
>>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
|
||||
>>> # Input in concatenation form
|
||||
>>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
|
||||
@ -4381,7 +4392,7 @@ def all_to_all_single(
|
||||
|
||||
>>> # Essentially, it is similar to following operation:
|
||||
>>> scatter_list = list(input.chunk(world_size))
|
||||
>>> gather_list = list(output.chunk(world_size))
|
||||
>>> gather_list = list(output.chunk(world_size))
|
||||
>>> for i in range(world_size):
|
||||
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
|
||||
|
||||
@ -4411,7 +4422,9 @@ def all_to_all_single(
|
||||
|
||||
|
||||
>>> # Another example with tensors of torch.cfloat type.
|
||||
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
|
||||
>>> input = torch.tensor(
|
||||
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
|
||||
... ) + 4 * rank * (1 + 1j)
|
||||
>>> input
|
||||
tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0
|
||||
tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1
|
||||
@ -4510,7 +4523,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
|
||||
|
||||
>>> # Essentially, it is similar to following operation:
|
||||
>>> scatter_list = input
|
||||
>>> gather_list = output
|
||||
>>> gather_list = output
|
||||
>>> for i in range(world_size):
|
||||
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
|
||||
|
||||
@ -4544,7 +4557,9 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
|
||||
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3
|
||||
|
||||
>>> # Another example with tensors of torch.cfloat type.
|
||||
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
|
||||
>>> input = torch.tensor(
|
||||
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
|
||||
... ) + 4 * rank * (1 + 1j)
|
||||
>>> input = list(input.chunk(4))
|
||||
>>> input
|
||||
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0
|
||||
@ -4882,9 +4897,9 @@ def split_group(
|
||||
backend_config = BackendConfig(backend)
|
||||
|
||||
if pg_options is not None:
|
||||
assert isinstance(
|
||||
pg_options, ProcessGroupNCCL.Options
|
||||
), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
|
||||
assert isinstance(pg_options, ProcessGroupNCCL.Options), (
|
||||
"Expected pg_options argument to be of type ProcessGroupNCCL.Options"
|
||||
)
|
||||
else:
|
||||
# default pg_options same as the parent process group
|
||||
pg_options = parent_backend.options
|
||||
@ -5086,9 +5101,9 @@ def _new_group_with_tag(
|
||||
if device_id is None:
|
||||
device_id = default_pg.bound_device_id
|
||||
elif default_pg.bound_device_id is not None:
|
||||
assert (
|
||||
device_id == default_pg.bound_device_id
|
||||
), "Mismatched bound device between new pg and the default pg."
|
||||
assert device_id == default_pg.bound_device_id, (
|
||||
"Mismatched bound device between new pg and the default pg."
|
||||
)
|
||||
default_backend, default_store = _world.pg_map[default_pg]
|
||||
global_rank = default_pg.rank()
|
||||
global_world_size = default_pg.size()
|
||||
@ -5408,9 +5423,9 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro
|
||||
def _find_or_create_pg_by_ranks_and_tag(
|
||||
tag: str, ranks: list[int], stride: int
|
||||
) -> ProcessGroup:
|
||||
assert (
|
||||
len(ranks) % stride == 0
|
||||
), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
|
||||
assert len(ranks) % stride == 0, (
|
||||
f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
|
||||
)
|
||||
|
||||
my_rank = get_rank()
|
||||
my_ranks = None
|
||||
|
@ -40,8 +40,9 @@ def worker_main() -> Generator[None, None, None]:
|
||||
def main():
|
||||
pass
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
with ExitStack() as stack:
|
||||
|
@ -14,7 +14,10 @@ Example of usage:
|
||||
::
|
||||
|
||||
from torch.distributed.elastic import events
|
||||
event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
|
||||
|
||||
event = events.Event(
|
||||
name="test_event", source=events.EventSource.WORKER, metadata={...}
|
||||
)
|
||||
events.get_logging_handler(destination="console").info(event)
|
||||
|
||||
"""
|
||||
|
@ -52,11 +52,12 @@ The example below measures the latency for the ``calculate()`` function.
|
||||
metrics.configure(metrics.NullMetricsHandler())
|
||||
metrics.configure(metrics.ConsoleMetricsHandler(), "my_module")
|
||||
|
||||
|
||||
def my_method():
|
||||
start = time.time()
|
||||
calculate()
|
||||
end = time.time()
|
||||
metrics.put_metric("calculate_latency", int(end-start), "my_module")
|
||||
start = time.time()
|
||||
calculate()
|
||||
end = time.time()
|
||||
metrics.put_metric("calculate_latency", int(end - start), "my_module")
|
||||
|
||||
You may also use the torch.distributed.elastic.metrics.prof` decorator
|
||||
to conveniently and succinctly profile functions
|
||||
@ -70,15 +71,16 @@ to conveniently and succinctly profile functions
|
||||
metrics.configure(metrics.ConsoleMetricsHandler(), "foobar")
|
||||
metrics.configure(metrics.ConsoleMetricsHandler(), "Bar")
|
||||
|
||||
|
||||
@metrics.prof
|
||||
def foo():
|
||||
pass
|
||||
pass
|
||||
|
||||
class Bar():
|
||||
|
||||
@metrics.prof
|
||||
def baz():
|
||||
pass
|
||||
class Bar:
|
||||
@metrics.prof
|
||||
def baz():
|
||||
pass
|
||||
|
||||
``@metrics.prof`` will publish the following metrics
|
||||
::
|
||||
@ -102,8 +104,8 @@ console.
|
||||
|
||||
import torch.distributed.elastic.metrics as metrics
|
||||
|
||||
metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic")
|
||||
metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app")
|
||||
metrics.configure(metrics.ConsoleMetricHandler(), group="torchelastic")
|
||||
metrics.configure(metrics.ConsoleMetricHandler(), group="my_app")
|
||||
|
||||
**Writing a Custom Metric Handler**:
|
||||
|
||||
@ -117,13 +119,15 @@ Below is a toy example that prints the metrics to ``stdout``
|
||||
|
||||
import torch.distributed.elastic.metrics as metrics
|
||||
|
||||
|
||||
class StdoutMetricHandler(metrics.MetricHandler):
|
||||
def emit(self, metric_data):
|
||||
ts = metric_data.timestamp
|
||||
group = metric_data.group_name
|
||||
name = metric_data.name
|
||||
value = metric_data.value
|
||||
print(f"[{ts}][{group}]: {name}={value}")
|
||||
def emit(self, metric_data):
|
||||
ts = metric_data.timestamp
|
||||
group = metric_data.group_name
|
||||
name = metric_data.name
|
||||
value = metric_data.value
|
||||
print(f"[{ts}][{group}]: {name}={value}")
|
||||
|
||||
|
||||
metrics.configure(StdoutMetricHandler(), group="my_app")
|
||||
|
||||
|
@ -123,6 +123,7 @@ def prof(fn=None, group: str = "torchelastic"):
|
||||
def x():
|
||||
pass
|
||||
|
||||
|
||||
@metrics.prof(group="agent")
|
||||
def y():
|
||||
pass
|
||||
|
@ -20,22 +20,23 @@ Usage 1: Launching two trainers as a function
|
||||
|
||||
from torch.distributed.elastic.multiprocessing import Std, start_processes
|
||||
|
||||
|
||||
def trainer(a, b, c):
|
||||
pass # train
|
||||
pass # train
|
||||
|
||||
|
||||
# runs two trainers
|
||||
# LOCAL_RANK=0 trainer(1,2,3)
|
||||
# LOCAL_RANK=1 trainer(4,5,6)
|
||||
ctx = start_processes(
|
||||
name="trainer",
|
||||
entrypoint=trainer,
|
||||
args={0: (1,2,3), 1: (4,5,6)},
|
||||
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
|
||||
log_dir="/tmp/foobar",
|
||||
redirects=Std.ALL, # write all worker stdout/stderr to a log file
|
||||
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
|
||||
)
|
||||
name="trainer",
|
||||
entrypoint=trainer,
|
||||
args={0: (1, 2, 3), 1: (4, 5, 6)},
|
||||
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
|
||||
log_dir="/tmp/foobar",
|
||||
redirects=Std.ALL, # write all worker stdout/stderr to a log file
|
||||
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
|
||||
)
|
||||
|
||||
# waits for all copies of trainer to finish
|
||||
ctx.wait()
|
||||
|
@ -165,9 +165,11 @@ def to_map(
|
||||
Example:
|
||||
::
|
||||
|
||||
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
|
||||
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
|
||||
to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
|
||||
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
|
||||
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
|
||||
to_map(
|
||||
{0: Std.OUT, 1: Std.OUT}, local_world_size=2
|
||||
) # returns: {0: Std.OUT, 1: Std.OUT}
|
||||
"""
|
||||
if isinstance(val_or_map, Std):
|
||||
return dict.fromkeys(range(local_world_size), val_or_map)
|
||||
@ -304,7 +306,9 @@ class DefaultLogsSpecs(LogsSpecs):
|
||||
if not self._run_log_dir:
|
||||
self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
|
||||
|
||||
attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload]
|
||||
attempt_log_dir = os.path.join(
|
||||
self._run_log_dir, f"attempt_{restart_count}"
|
||||
) # type: ignore[call-overload]
|
||||
shutil.rmtree(attempt_log_dir, ignore_errors=True)
|
||||
os.makedirs(attempt_log_dir)
|
||||
|
||||
@ -868,9 +872,7 @@ class SubprocessContext(PContext):
|
||||
if result.is_failed():
|
||||
first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
|
||||
logger.error(
|
||||
"failed (exitcode: %s)"
|
||||
" local_rank: %s (pid: %s)"
|
||||
" of binary: %s",
|
||||
"failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s",
|
||||
first_failure.exitcode,
|
||||
first_failure.local_rank,
|
||||
first_failure.pid,
|
||||
|
@ -318,14 +318,14 @@ def record(
|
||||
error_handler = get_error_handler()
|
||||
error_handler.initialize()
|
||||
try:
|
||||
foobar()
|
||||
foobar()
|
||||
except ChildFailedError as e:
|
||||
_, failure = e.get_first_failure()
|
||||
error_handler.dump_error_file(failure.error_file, failure.exitcode)
|
||||
raise
|
||||
_, failure = e.get_first_failure()
|
||||
error_handler.dump_error_file(failure.error_file, failure.exitcode)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_handler.record_exception(e)
|
||||
raise
|
||||
error_handler.record_exception(e)
|
||||
raise
|
||||
|
||||
.. important:: use this decorator once per process at the top level method,
|
||||
typically this is the main method.
|
||||
@ -338,8 +338,9 @@ def record(
|
||||
def main():
|
||||
pass
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
if not error_handler:
|
||||
|
@ -120,11 +120,7 @@ of the following implementations that come with PyTorch:
|
||||
backend = C10dRendezvousBackend(store, "my_run_id")
|
||||
|
||||
rdzv_handler = DynamicRendezvousHandler.from_backend(
|
||||
run_id="my_run_id",
|
||||
store=store,
|
||||
backend=backend,
|
||||
min_nodes=2,
|
||||
max_nodes=4
|
||||
run_id="my_run_id", store=store, backend=backend, min_nodes=2, max_nodes=4
|
||||
)
|
||||
"""
|
||||
|
||||
|
@ -89,8 +89,14 @@ class RendezvousStoreInfo:
|
||||
addr = local_addr or socket.getfqdn()
|
||||
# When TCPStore is not shared, we fallback to get_free_port.
|
||||
port = server_port or get_free_port()
|
||||
store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type]
|
||||
store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type]
|
||||
store.set(
|
||||
RendezvousStoreInfo.MASTER_ADDR_KEY,
|
||||
addr.encode(encoding="UTF-8"), # type: ignore[arg-type]
|
||||
)
|
||||
store.set(
|
||||
RendezvousStoreInfo.MASTER_PORT_KEY,
|
||||
str(port).encode(encoding="UTF-8"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
|
||||
port = int(
|
||||
|
@ -413,9 +413,9 @@ class EtcdRendezvous:
|
||||
active_version = self.wait_for_peers(expected_version)
|
||||
state = json.loads(active_version.value)
|
||||
|
||||
assert (
|
||||
state["version"] == expected_version
|
||||
), "Logic error: failed to observe version mismatch"
|
||||
assert state["version"] == expected_version, (
|
||||
"Logic error: failed to observe version mismatch"
|
||||
)
|
||||
|
||||
return self.confirm_phase(expected_version, this_rank)
|
||||
|
||||
@ -533,9 +533,9 @@ class EtcdRendezvous:
|
||||
"Rendezvous version changed. Must try join the new one."
|
||||
)
|
||||
|
||||
assert (
|
||||
len(state["participants"]) < self._num_max_workers
|
||||
), "Logic error: joinable rendezvous should always have space left"
|
||||
assert len(state["participants"]) < self._num_max_workers, (
|
||||
"Logic error: joinable rendezvous should always have space left"
|
||||
)
|
||||
|
||||
this_rank = len(state["participants"])
|
||||
state["participants"].append(this_rank)
|
||||
|
@ -86,11 +86,15 @@ def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
|
||||
from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
|
||||
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
|
||||
|
||||
|
||||
def create_my_rdzv(params: RendezvousParameters):
|
||||
return MyCustomRdzv(params)
|
||||
return MyCustomRdzv(params)
|
||||
|
||||
|
||||
rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
|
||||
|
||||
my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters)
|
||||
my_rdzv_handler = get_rendezvous_handler(
|
||||
"my_rdzv_backend_name", RendezvousParameters
|
||||
)
|
||||
"""
|
||||
return handler_registry.create_handler(params)
|
||||
|
@ -57,10 +57,10 @@ def get_all(store, rank: int, prefix: str, world_size: int):
|
||||
|
||||
::
|
||||
|
||||
values = get_all(store, 'torchelastic/data', 3)
|
||||
value1 = values[0] # retrieves the data for key torchelastic/data0
|
||||
value2 = values[1] # retrieves the data for key torchelastic/data1
|
||||
value3 = values[2] # retrieves the data for key torchelastic/data2
|
||||
values = get_all(store, "torchelastic/data", 3)
|
||||
value1 = values[0] # retrieves the data for key torchelastic/data0
|
||||
value2 = values[1] # retrieves the data for key torchelastic/data1
|
||||
value3 = values[2] # retrieves the data for key torchelastic/data2
|
||||
|
||||
"""
|
||||
data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])
|
||||
|
@ -2,6 +2,7 @@
"""
This file includes private common utilities for FSDP.
"""

import logging
import traceback
import warnings
@ -200,9 +201,9 @@ def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamH
# handles, meaning no entry in `_fully_sharded_module_to_handles`
if state._handle is None:
return None
assert (
module in state._fully_sharded_module_to_handle
), f"Expects a fully sharded module but got {module} on rank {state.rank}"
assert module in state._fully_sharded_module_to_handle, (
f"Expects a fully sharded module but got {module} on rank {state.rank}"
)
return state._fully_sharded_module_to_handle[module]
else:
# NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
@ -255,9 +256,9 @@ def _named_parameters_with_duplicates(
This API is required as some modules overwrite `named_parameters()` but do not support
`remove_duplicate`.
"""
assert (
"remove_duplicate" not in kwargs
), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
assert "remove_duplicate" not in kwargs, (
"_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
)
kwargs["remove_duplicate"] = False
try:
ret = list(module.named_parameters(**kwargs))

@ -190,9 +190,9 @@ class _ExecOrderData:
return
if self.is_first_iter:
msg_prefix = "Forward order differs across ranks:"
optional_local_indices: tuple[
Optional[int], ...
] = self._get_handle_indices(handle)
optional_local_indices: tuple[Optional[int], ...] = (
self._get_handle_indices(handle)
)
device = handle.device # guaranteed to be non-CPU
num_valid_indices = sum(
(index is not None) for index in optional_local_indices
@ -250,8 +250,7 @@ class _ExecOrderData:
(
rank,
world_indices[
rank
* num_valid_indices : (rank + 1)
rank * num_valid_indices : (rank + 1)
* num_valid_indices
],
)

@ -586,7 +586,10 @@ class FlatParamHandle:
)
self._fsdp_extension = fsdp_extension
self._init_flat_param_and_metadata(
params, fully_sharded_module, self._aligned_numel, use_orig_params # type: ignore[arg-type]
params,
fully_sharded_module,
self._aligned_numel,
use_orig_params, # type: ignore[arg-type]
)
self._use_unsharded_views(as_params=False)

@ -978,9 +981,9 @@ class FlatParamHandle:
shard_param_infos = self._get_shard_metadata(
unsharded_start_idx, unsharded_end_idx
)
assert (
len(shard_param_infos) == flat_param._num_params
), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
assert len(shard_param_infos) == flat_param._num_params, (
f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
)
flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]

@ -996,9 +999,9 @@ class FlatParamHandle:
unsharded flat parameter specifying the shard.
"""
flat_param_offsets = self._get_flat_param_offsets()
assert len(flat_param_offsets) == len(
self.flat_param._numels_with_padding
), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), (
f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
)
shard_param_infos: list[_ShardParamInfo] = []
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
@ -1075,9 +1078,9 @@ class FlatParamHandle:
else:
chunk = chunks[rank]
numel_to_pad = chunks[0].numel() - chunk.numel()
assert (
numel_to_pad >= 0
), "Chunk's size should be at most the first chunk's size"
assert numel_to_pad >= 0, (
"Chunk's size should be at most the first chunk's size"
)
return chunk, numel_to_pad

@staticmethod
@ -1302,7 +1305,8 @@ class FlatParamHandle:
self._check_low_precision_shard()
flat_param = self.flat_param
_alloc_storage(
flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined]
flat_param._mp_shard,
flat_param._local_shard.size(), # type: ignore[attr-defined]
)
# `copy_()` implicitly casts to the low precision
flat_param._mp_shard.copy_( # type: ignore[attr-defined]
@ -1498,7 +1502,8 @@ class FlatParamHandle:
# default stream suffices since the default stream waits for the
# unshard stream.
_no_dispatch_record_stream(
self.flat_param._mp_shard, self._device_handle.current_stream() # type: ignore[attr-defined]
self.flat_param._mp_shard,
self._device_handle.current_stream(), # type: ignore[attr-defined]
)
_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]

@ -1593,8 +1598,7 @@ class FlatParamHandle:
f"but got {flat_param.grad.device}",
)
prev_iter_synced_gradients = (
flat_param.grad.size()
== flat_param._local_shard.size() # type: ignore[attr-defined]
flat_param.grad.size() == flat_param._local_shard.size() # type: ignore[attr-defined]
)
if prev_iter_synced_gradients:
# TODO (awgu): Gradient accumulation outside `no_sync()`
@ -1668,8 +1672,7 @@ class FlatParamHandle:
cast_grad_to_param_dtype_if_needed(flat_param)
else:
_p_assert(
not self.uses_sharded_strategy
or not flat_param._post_backward_called, # type: ignore[attr-defined]
not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined]
"All sharded parameters that received a gradient in the "
"post-backward should use `_saved_grad_shard`",
)
@ -2504,7 +2507,8 @@ class FlatParamHandle:
"""Return the FQNs of the parameters present in this rank's shard."""
fqns_in_shard: list[str] = []
for fqn, shard_param_info in zip(
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
self.flat_param._fqns,
self.flat_param._shard_param_infos, # type: ignore[attr-defined]
):
if shard_param_info.in_shard:
fqns_in_shard.append(fqn)
@ -2694,7 +2698,7 @@ def _safe_setattr_tensor_or_param(


def _convert_to_params(
tensors: list[Union[torch.Tensor, nn.Parameter]]
tensors: list[Union[torch.Tensor, nn.Parameter]],
) -> list[nn.Parameter]:
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]

@ -374,9 +374,9 @@ def foreach_reduce(
for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)):
if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
continue
assert (
unsharded_grad.size(shard_dim) % world_size == 0
), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
assert unsharded_grad.size(shard_dim) % world_size == 0, (
f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
)
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
unsharded_grads[i] = torch.cat(chunks, dim=0)
padded_unsharded_sizes = tuple(

@ -26,9 +26,9 @@ if torch._running_with_deploy():
else:

def detect_compiled_autograd():
assert (
not torch.compiler.is_compiling()
), "`detect_compiled_autograd()` is designed to be called in eager mode"
assert not torch.compiler.is_compiling(), (
"`detect_compiled_autograd()` is designed to be called in eager mode"
)
global _compiled_autograd_enabled
import torch._dynamo.compiled_autograd as ca


|
||||
f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
|
||||
)
|
||||
split_factor = self._tp_spec.num_shards_map[shard_dim]
|
||||
assert (
|
||||
2 <= self._spmd_mesh.ndim <= 3
|
||||
), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
|
||||
assert 2 <= self._spmd_mesh.ndim <= 3, (
|
||||
f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
|
||||
)
|
||||
self._spmd_placements: tuple[Placement, ...]
|
||||
dp_shard_tp_placement = (
|
||||
(
|
||||
@ -520,8 +520,9 @@ class FSDPParam:
|
||||
unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
|
||||
if hasattr(self, "_unsharded_param"):
|
||||
assert compiled_autograd_enabled()
|
||||
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
|
||||
self._unsharded_param
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autograd._unsafe_preserve_version_counter(self._unsharded_param),
|
||||
):
|
||||
# NOTE: Under compile, if an unsharded param goes through
|
||||
# resize_(full) -> copy_ -> resize_(0) pattern, we will remove those
|
||||
@ -785,9 +786,9 @@ class FSDPParam:
|
||||
assert isinstance(grad, DTensor), f"{type(grad)}"
|
||||
placements = self._tp_spec.placements
|
||||
if placements != grad.placements:
|
||||
assert len(self._tp_spec.placements) == len(
|
||||
grad.placements
|
||||
), f"{self._tp_spec=} {grad.placements=}"
|
||||
assert len(self._tp_spec.placements) == len(grad.placements), (
|
||||
f"{self._tp_spec=} {grad.placements=}"
|
||||
)
|
||||
grad = grad.redistribute(placements=placements)
|
||||
grad = grad._local_tensor
|
||||
return grad
|
||||
@ -846,9 +847,9 @@ class FSDPParam:
|
||||
shard_dim = self.fsdp_placement.dim
|
||||
length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
|
||||
if local_tensor.size() != padded_sharded_size:
|
||||
assert (
|
||||
shard_dim == 0
|
||||
), f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
|
||||
assert shard_dim == 0, (
|
||||
f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
|
||||
)
|
||||
padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
|
||||
padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(
|
||||
local_tensor
|
||||
|
@ -424,9 +424,9 @@ class FSDPParamGroup:
|
||||
if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
|
||||
# this means the native HSDP is not enabled,
|
||||
# but user may want to have a custom HSDP setup
|
||||
assert (
|
||||
self._all_reduce_hook is not None
|
||||
), "all reduce hook stream is specified but hook itself is missing."
|
||||
assert self._all_reduce_hook is not None, (
|
||||
"all reduce hook stream is specified but hook itself is missing."
|
||||
)
|
||||
all_reduce_stream = self._all_reduce_hook_stream
|
||||
else:
|
||||
all_reduce_stream = self.comm_ctx.all_reduce_stream
|
||||
@ -513,9 +513,10 @@ class FSDPParamGroup:
|
||||
else:
|
||||
raise ValueError(f"Unknown pass type: {pass_type}")
|
||||
target_fqn = target_fsdp_param_group._module_fqn
|
||||
with record_function(
|
||||
f"FSDP::{pass_type}_prefetch for {target_fqn}"
|
||||
), target_fsdp_param_group.use_training_state(training_state):
|
||||
with (
|
||||
record_function(f"FSDP::{pass_type}_prefetch for {target_fqn}"),
|
||||
target_fsdp_param_group.use_training_state(training_state),
|
||||
):
|
||||
async_op = target_fsdp_param_group.unshard_async_op
|
||||
target_fsdp_param_group.unshard(async_op)
|
||||
|
||||
@ -592,9 +593,9 @@ class FSDPParamGroup:
|
||||
def _register_state_dict_hooks(self) -> None:
|
||||
num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
|
||||
num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
|
||||
assert (
|
||||
num_pre_save_hooks == num_pre_load_hooks
|
||||
), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
|
||||
assert num_pre_save_hooks == num_pre_load_hooks, (
|
||||
f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
|
||||
)
|
||||
if num_pre_save_hooks > 0:
|
||||
return # already registered
|
||||
modules_with_fsdp_params: set[nn.Module] = {
|
||||
@ -605,12 +606,12 @@ class FSDPParamGroup:
|
||||
self._to_sharded()
|
||||
|
||||
for module in modules_with_fsdp_params:
|
||||
self._module_to_pre_save_state_dict_hook_handle[
|
||||
module
|
||||
] = module.register_state_dict_pre_hook(to_sharded_hook)
|
||||
self._module_to_pre_load_state_dict_hook_handle[
|
||||
module
|
||||
] = module._register_load_state_dict_pre_hook(to_sharded_hook)
|
||||
self._module_to_pre_save_state_dict_hook_handle[module] = (
|
||||
module.register_state_dict_pre_hook(to_sharded_hook)
|
||||
)
|
||||
self._module_to_pre_load_state_dict_hook_handle[module] = (
|
||||
module._register_load_state_dict_pre_hook(to_sharded_hook)
|
||||
)
|
||||
|
||||
# Properties #
|
||||
@property
|
||||
|
@ -60,8 +60,7 @@ def fully_shard(
|
||||
mp_policy: MixedPrecisionPolicy = ...,
|
||||
offload_policy: OffloadPolicy = ...,
|
||||
ignored_params: Optional[set[nn.Parameter]] = ...,
|
||||
) -> FSDPModule:
|
||||
...
|
||||
) -> FSDPModule: ...
|
||||
|
||||
|
||||
@overload
|
||||
@ -74,8 +73,7 @@ def fully_shard(
|
||||
mp_policy: MixedPrecisionPolicy = ...,
|
||||
offload_policy: OffloadPolicy = ...,
|
||||
ignored_params: Optional[set[nn.Parameter]] = ...,
|
||||
) -> list[FSDPModule]:
|
||||
...
|
||||
) -> list[FSDPModule]: ...
|
||||
|
||||
|
||||
# The decorator adds a state object to `module` that can be accessed via
|
||||
|
@ -243,9 +243,9 @@ def _init_inter_node_process_group(
|
||||
if local_rank == my_local_rank:
|
||||
inter_node_pg = grp
|
||||
|
||||
assert (
|
||||
inter_node_pg is not None
|
||||
), f"{my_local_rank} expected to assign inter-node pg, but did not"
|
||||
assert inter_node_pg is not None, (
|
||||
f"{my_local_rank} expected to assign inter-node pg, but did not"
|
||||
)
|
||||
return inter_node_pg
|
||||
|
||||
|
||||
|
@ -145,9 +145,9 @@ def _unflatten_optim_state(
|
||||
dict will need to map these entries using the proper unflattened
|
||||
parameter IDs.
|
||||
"""
|
||||
assert (
|
||||
not shard_state or to_save
|
||||
), "If ``shard_state`` is True, ``to_save`` has to be True."
|
||||
assert not shard_state or to_save, (
|
||||
"If ``shard_state`` is True, ``to_save`` has to be True."
|
||||
)
|
||||
consolidated_state = _communicate_optim_state(
|
||||
fsdp_param_info,
|
||||
flat_param_state,
|
||||
@ -218,9 +218,9 @@ def _communicate_optim_state(
|
||||
):
|
||||
tensor_state[state_name] = value
|
||||
continue
|
||||
assert (
|
||||
fsdp_state.compute_device is not None
|
||||
), "compute_device has not been initialized"
|
||||
assert fsdp_state.compute_device is not None, (
|
||||
"compute_device has not been initialized"
|
||||
)
|
||||
if value.device.type != fsdp_state.compute_device.type:
|
||||
value = value.to(fsdp_state.compute_device)
|
||||
# Assume that positive-dimension tensor optimizer state
|
||||
@ -394,7 +394,10 @@ def _shard_orig_param_state(
|
||||
and value.dim() > 0
|
||||
and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
|
||||
):
|
||||
value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator]
|
||||
value = value.flatten()[
|
||||
intra_param_start_idx : intra_param_end_idx # type: ignore[operator]
|
||||
+ 1
|
||||
].clone()
|
||||
new_optim_state[state_name] = value
|
||||
return new_optim_state
|
||||
|
||||
@ -489,9 +492,9 @@ def _flatten_optim_state_dict(
|
||||
if flat_state:
|
||||
flat_osd_state[key] = flat_state
|
||||
elif use_orig_params:
|
||||
assert (
|
||||
len(fqns) == 1
|
||||
), f"use_orig_params is True but there are multiple FQNs, {fqns}."
|
||||
assert len(fqns) == 1, (
|
||||
f"use_orig_params is True but there are multiple FQNs, {fqns}."
|
||||
)
|
||||
if optim is not None: # NamedOptimizer or KeyedOptimizer case.
|
||||
state = optim.state.get(param, None) # type: ignore[call-overload]
|
||||
if state is not None:
|
||||
@ -570,14 +573,13 @@ def _flatten_optim_state(
|
||||
flat_param = handle.flat_param
|
||||
num_unflat_params = len(unflat_param_names)
|
||||
assert num_unflat_params > 0, (
|
||||
"Expects at least one unflattened parameter corresponding to the "
|
||||
"flat parameter"
|
||||
"Expects at least one unflattened parameter corresponding to the flat parameter"
|
||||
)
|
||||
unflat_param_shapes = flat_param._shapes
|
||||
num_unflat_param_shapes = len(unflat_param_shapes)
|
||||
assert (
|
||||
num_unflat_params == num_unflat_param_shapes
|
||||
), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
|
||||
assert num_unflat_params == num_unflat_param_shapes, (
|
||||
f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
|
||||
)
|
||||
|
||||
# Check if these unflattened parameters have any optimizer state
|
||||
has_state = [
|
||||
@ -759,8 +761,7 @@ def _flatten_tensor_optim_state(
|
||||
flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
|
||||
flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined]
|
||||
assert flat_tensor.shape == flat_param_shape, (
|
||||
f"tensor optim state: {flat_tensor.shape} "
|
||||
f"flat parameter: {flat_param_shape}"
|
||||
f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}"
|
||||
)
|
||||
return flat_tensor
|
||||
|
||||
@ -1065,9 +1066,9 @@ def _get_param_key_to_param(
|
||||
"""
|
||||
clean_fqn_to_curr_fqn: dict[str, str] = {}
|
||||
if is_named_optimizer:
|
||||
assert (
|
||||
param_to_fqns is not None and flat_param_to_fqn is not None
|
||||
), "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
|
||||
assert param_to_fqns is not None and flat_param_to_fqn is not None, (
|
||||
"The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
|
||||
)
|
||||
assert model is not None
|
||||
for key, _ in _named_parameters_with_duplicates(model):
|
||||
clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
|
||||
@ -1150,9 +1151,9 @@ def _check_missing_keys_on_rank(
|
||||
continue
|
||||
param_key = optim_state_key_to_param_key[r0_optim_state_key]
|
||||
if isinstance(param_key, int):
|
||||
assert param_key >= 0 and param_key < len(
|
||||
param_key_to_param
|
||||
), "Check the `param_key_to_param` construction"
|
||||
assert param_key >= 0 and param_key < len(param_key_to_param), (
|
||||
"Check the `param_key_to_param` construction"
|
||||
)
|
||||
# We cannot use FSDPState.compute_device as this API is a global view.
|
||||
device = _get_pg_default_device(group)
|
||||
num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)
|
||||
|
@ -121,9 +121,9 @@ def _all_gather_dtensor(
|
||||
"""
|
||||
All gather a DTensor in its sharded dimension and return the local tensor.
|
||||
"""
|
||||
assert (
|
||||
root_mesh == tensor.device_mesh
|
||||
), "The device mesh of a tensor should be a root mesh."
|
||||
assert root_mesh == tensor.device_mesh, (
|
||||
"The device mesh of a tensor should be a root mesh."
|
||||
)
|
||||
|
||||
placements = list(copy.deepcopy(tensor.placements))
|
||||
# FSDP placements: [Shard(0)] -> [Replicate()]
|
||||
|
@ -466,9 +466,9 @@ def _local_pre_load_state_dict_hook(
|
||||
)
|
||||
return
|
||||
load_tensor = state_dict[fqn]
|
||||
assert isinstance(
|
||||
load_tensor, ShardedTensor
|
||||
), "Tensors in local_state_dict should be ShardedTensor."
|
||||
assert isinstance(load_tensor, ShardedTensor), (
|
||||
"Tensors in local_state_dict should be ShardedTensor."
|
||||
)
|
||||
|
||||
# Convert the ShardedTensor to a Tensor.
|
||||
flat_param = _module_handle(fsdp_state, module).flat_param
|
||||
|
@ -143,9 +143,9 @@ class _ExecOrderTracer:
|
||||
named_params = list(module.named_parameters())
|
||||
curr_module = exec_info.curr_module
|
||||
if named_params:
|
||||
assert (
|
||||
curr_module in exec_info.module_to_param_usage_infos
|
||||
), "The current module should have already been processed by a patched `call_module`"
|
||||
assert curr_module in exec_info.module_to_param_usage_infos, (
|
||||
"The current module should have already been processed by a patched `call_module`"
|
||||
)
|
||||
exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
|
||||
_ParamUsageInfo(module, named_params)
|
||||
)
|
||||
|
@ -185,9 +185,9 @@ def _unshard_fsdp_state_params(
|
||||
yield
|
||||
return
|
||||
|
||||
assert (
|
||||
handle._training_state == HandleTrainingState.IDLE
|
||||
), f"Expects the handle training to be IDLE but got {handle._training_state}"
|
||||
assert handle._training_state == HandleTrainingState.IDLE, (
|
||||
f"Expects the handle training to be IDLE but got {handle._training_state}"
|
||||
)
|
||||
|
||||
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS
|
||||
|
||||
|
@ -306,16 +306,21 @@ class FullStateDictConfig(StateDictConfig):
|
||||
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
|
||||
>>> state = fsdp.state_dict()
|
||||
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
|
||||
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
|
||||
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
|
||||
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
|
||||
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
|
||||
>>> if dist.get_rank() == 0:
|
||||
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
|
||||
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
|
||||
>>> state_dict = torch.load("my_checkpoint.pt")
|
||||
>>> model.load_state_dict(state_dict)
|
||||
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
|
||||
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
|
||||
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
|
||||
>>> fsdp = FSDP(
|
||||
... model,
|
||||
... device_id=torch.cuda.current_device(),
|
||||
... auto_wrap_policy=...,
|
||||
... sync_module_states=True,
|
||||
... )
|
||||
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
|
||||
|
||||
Attributes:
|
||||
|
@ -723,9 +723,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
||||
if prev_state_dict_type is None:
|
||||
prev_state_dict_type = submodule._state_dict_type
|
||||
else:
|
||||
assert (
|
||||
prev_state_dict_type == submodule._state_dict_type
|
||||
), "All FSDP modules should have the same state_dict_type."
|
||||
assert prev_state_dict_type == submodule._state_dict_type, (
|
||||
"All FSDP modules should have the same state_dict_type."
|
||||
)
|
||||
if prev_state_dict_config is None:
|
||||
prev_state_dict_config = submodule._state_dict_config
|
||||
else:
|
||||
@ -738,7 +738,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
||||
assert isinstance(
|
||||
submodule._optim_state_dict_config,
|
||||
type(prev_optim_state_dict_config),
|
||||
), "All FSDP modules must have the same type of optim_state_dict_config."
|
||||
), (
|
||||
"All FSDP modules must have the same type of optim_state_dict_config."
|
||||
)
|
||||
|
||||
submodule._state_dict_type = state_dict_type
|
||||
submodule._state_dict_config = state_dict_config
|
||||
@ -2153,9 +2155,9 @@ def _get_param_to_fqn(
|
||||
"""
|
||||
param_to_param_names = _get_param_to_fqns(model)
|
||||
for param_names in param_to_param_names.values():
|
||||
assert (
|
||||
len(param_names) > 0
|
||||
), "`_get_param_to_fqns()` should not construct empty lists"
|
||||
assert len(param_names) > 0, (
|
||||
"`_get_param_to_fqns()` should not construct empty lists"
|
||||
)
|
||||
if len(param_names) > 1:
|
||||
raise RuntimeError(
|
||||
"Each parameter should only map to one parameter name but got "
|
||||
|
@ -112,20 +112,16 @@ class ShardedGradScaler(GradScaler):
|
||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||
|
||||
@overload
|
||||
def scale(self, outputs: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
def scale(self, outputs: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
@overload
|
||||
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
...
|
||||
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ...
|
||||
|
||||
@overload
|
||||
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
|
||||
...
|
||||
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ...
|
||||
|
||||
@overload
|
||||
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
|
||||
...
|
||||
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ...
|
||||
|
||||
def scale(
|
||||
self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
|
||||
@ -323,8 +319,10 @@ class ShardedGradScaler(GradScaler):
|
||||
if isinstance(new_scale, float):
|
||||
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
|
||||
reason = (
|
||||
"new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
|
||||
torch.FloatTensor with requires_grad=False."
|
||||
)
|
||||
assert new_scale.device.type == self._device, reason
|
||||
assert new_scale.numel() == 1, reason
|
||||
assert new_scale.requires_grad is False, reason
|
||||
|
@ -61,9 +61,9 @@ def _post_order_apply(
|
||||
"Non-root modules should have their module name set but got "
|
||||
f"an empty module name for {module}"
|
||||
)
|
||||
assert isinstance(
|
||||
optional_module, nn.Module
|
||||
), f"fn should return None or an nn.Module but got {optional_module}"
|
||||
assert isinstance(optional_module, nn.Module), (
|
||||
f"fn should return None or an nn.Module but got {optional_module}"
|
||||
)
|
||||
setattr(parent_module, module_name, optional_module)
|
||||
|
||||
_post_order_apply_inner(root_module, "", None)
|
||||
@ -575,9 +575,9 @@ class _ConfigAutoWrap:
|
||||
)
|
||||
_ConfigAutoWrap.in_autowrap_context = True
|
||||
# Get and save the wrapper cls for the context.
|
||||
assert (
|
||||
"wrapper_cls" in kwargs.keys()
|
||||
), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
|
||||
assert "wrapper_cls" in kwargs.keys(), (
|
||||
"Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
|
||||
)
|
||||
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
|
||||
del kwargs["wrapper_cls"]
|
||||
# Save the rest.
|
||||
|
@ -183,8 +183,7 @@ def parse_args(args):
|
||||
def launch(args):
|
||||
if args.no_python and not args.use_env:
|
||||
raise ValueError(
|
||||
"When using the '--no-python' flag,"
|
||||
" you must also set the '--use-env' flag."
|
||||
"When using the '--no-python' flag, you must also set the '--use-env' flag."
|
||||
)
|
||||
run(args)
|
||||
|
||||
|
@ -39,7 +39,10 @@ _REMOTE_MODULE_PICKLED_ATTRIBUTES = (
|
||||
"module_rref",
|
||||
)
|
||||
|
||||
_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc]
|
||||
_SerializedRemoteModule = collections.namedtuple( # type: ignore[misc]
|
||||
"_SerializedRemoteModule",
|
||||
_REMOTE_MODULE_PICKLED_ATTRIBUTES,
|
||||
)
|
||||
|
||||
# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
|
||||
# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES
|
||||
|
@ -26,15 +26,15 @@ sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)
|
||||
|
||||
|
||||
def get_arg_return_types_from_interface(module_interface):
|
||||
assert getattr(
|
||||
module_interface, "__torch_script_interface__", False
|
||||
), "Expect a TorchScript class interface decorated by @torch.jit.interface."
|
||||
assert getattr(module_interface, "__torch_script_interface__", False), (
|
||||
"Expect a TorchScript class interface decorated by @torch.jit.interface."
|
||||
)
|
||||
qualified_name = torch._jit_internal._qualified_name(module_interface)
|
||||
cu = torch.jit._state._python_cu
|
||||
module_interface_c = cu.get_interface(qualified_name)
|
||||
assert (
|
||||
"forward" in module_interface_c.getMethodNames()
|
||||
), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
|
||||
assert "forward" in module_interface_c.getMethodNames(), (
|
||||
f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
|
||||
)
|
||||
method_schema = module_interface_c.getMethod("forward")
|
||||
|
||||
arg_str_list = []
|
||||
|
@ -5,6 +5,7 @@ optimizer locally on the workers where the parameters live. The distributed
|
||||
optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
|
||||
apply the gradients on each worker.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
@ -44,10 +44,10 @@ def _apply_optimizer_in_backward(
|
||||
param_1 = next(params_generator)
|
||||
remainder_params = list(params_generator)
|
||||
|
||||
apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
|
||||
apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})
|
||||
apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02})
|
||||
apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04})
|
||||
|
||||
model(...).sum().backward() # after backward, parameters will already
|
||||
model(...).sum().backward() # after backward, parameters will already
|
||||
# have their registered optimizer(s) applied.
|
||||
|
||||
"""
|
||||
@ -111,7 +111,7 @@ def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Opt
|
||||
List[torch.optim.Optimizer]: the in-backward optimizers.
|
||||
|
||||
Example::
|
||||
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01})
|
||||
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
|
||||
optims = _get_optimizers_in_backward(model)
|
||||
"""
|
||||
optims: list[torch.optim.Optimizer] = []
|
||||
|
@ -147,12 +147,10 @@ class _NamedOptimizer(optim.Optimizer):
|
||||
return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
|
||||
|
||||
@overload
|
||||
def step(self, closure: None = ...) -> None:
|
||||
...
|
||||
def step(self, closure: None = ...) -> None: ...
|
||||
|
||||
@overload
|
||||
def step(self, closure: Callable[[], float]) -> float:
|
||||
...
|
||||
def step(self, closure: Callable[[], float]) -> float: ...
|
||||
|
||||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
||||
"""
|
||||
|
@ -4,6 +4,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
r"""Zero Redundancy Optimizer."""
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import enum
|
||||
@ -262,9 +263,9 @@ class _OverlapInfo:
|
||||
meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
|
||||
in preparation for the next iteration.
|
||||
"""
|
||||
assert (
|
||||
len(self.broadcast_handles) == self.num_bucket_assignments
|
||||
), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
|
||||
assert len(self.broadcast_handles) == self.num_bucket_assignments, (
|
||||
f"Missing at least one broadcast handle on rank {dist.get_rank()}"
|
||||
)
|
||||
_ = [x.wait() for x in self.broadcast_handles]
|
||||
self.broadcast_handles.clear()
|
||||
|
||||
@ -909,9 +910,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
params_per_rank = overlap_info.params_per_rank
|
||||
offsets = overlap_info.offsets
|
||||
|
||||
self._bucket_assignments_per_rank_cache[assigned_rank][
|
||||
bucket_index
|
||||
] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
|
||||
self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = (
|
||||
_DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
|
||||
)
|
||||
if self.global_rank == assigned_rank:
|
||||
offsets[bucket_index] = len(params_per_rank[assigned_rank])
|
||||
params_per_rank[assigned_rank].extend(bucket_params)
|
||||
@ -927,9 +928,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
mapping bucket indices to :class:`_DDPBucketAssignment` s for each
|
||||
rank.
|
||||
"""
|
||||
assert (
|
||||
self._overlap_with_ddp
|
||||
), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`"
|
||||
assert self._overlap_with_ddp, (
|
||||
"`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`"
|
||||
)
|
||||
if len(self._bucket_assignments_per_rank_cache) > 0:
|
||||
return self._bucket_assignments_per_rank_cache
|
||||
|
||||
@ -1076,9 +1077,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
"Specifying `gradients` should not "
|
||||
"be used when `overlap_with_ddp=False`"
|
||||
)
|
||||
assert (
|
||||
closure is None
|
||||
), "`closure` is not supported when using a local functional optimizer"
|
||||
assert closure is None, (
|
||||
"`closure` is not supported when using a local functional optimizer"
|
||||
)
|
||||
loss = self.optim.step(gradients=gradients)
|
||||
|
||||
# Sync any updated attributes in the local optimizer to the exposed
|
||||
@ -1221,9 +1222,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
for rank, local_state_dict in enumerate(self._all_state_dicts):
|
||||
local_param_groups = local_state_dict["param_groups"]
|
||||
global_param_groups = self._partition_parameters()[rank]
|
||||
assert len(local_param_groups) == len(
|
||||
global_param_groups
|
||||
), "Mismatch between number of local and global parameter groups"
|
||||
assert len(local_param_groups) == len(global_param_groups), (
|
||||
"Mismatch between number of local and global parameter groups"
|
||||
)
|
||||
|
||||
for local_param_group, global_param_group in zip(
|
||||
local_param_groups, global_param_groups
|
||||
@ -1233,9 +1234,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
local_param_indices = local_param_group["params"]
|
||||
global_params = global_param_group["params"]
|
||||
|
||||
assert len(local_param_indices) == len(
|
||||
global_params
|
||||
), "Mismatch between number of local and global parameters in parameter group"
|
||||
assert len(local_param_indices) == len(global_params), (
|
||||
"Mismatch between number of local and global parameters in parameter group"
|
||||
)
|
||||
for local_param_index, global_param in zip(
|
||||
local_param_indices, global_params
|
||||
):
|
||||
@ -1268,9 +1269,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
dst_param_groups (list[dict]): parameter groups giving the
|
||||
attribute settings to set.
|
||||
"""
|
||||
assert len(src_param_groups) == len(
|
||||
dst_param_groups
|
||||
), "Mismatch between number of source and destination parameter groups"
|
||||
assert len(src_param_groups) == len(dst_param_groups), (
|
||||
"Mismatch between number of source and destination parameter groups"
|
||||
)
|
||||
for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
|
||||
# Sync all attributes except the parameters
|
||||
for attr in filter(lambda x: x != "params", src_param_group.keys()):
|
||||
@ -1479,9 +1480,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
|
||||
The local optimizer is saved in ``self.optim``.
|
||||
"""
|
||||
assert (
|
||||
self._optim_constructor is not None
|
||||
), "The local optimizer class has not been set"
|
||||
assert self._optim_constructor is not None, (
|
||||
"The local optimizer class has not been set"
|
||||
)
|
||||
|
||||
param_groups = self._partition_parameters()[self.rank]
|
||||
# `overlap_with_ddp=True` requires a local functional optimizer
|
||||
@ -1508,7 +1509,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
"error due to an empty parameter list",
|
||||
self._optim_constructor,
|
||||
)
|
||||
self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef]
|
||||
self.optim: Any = self._optim_constructor(
|
||||
params, **self._optim_defaults
|
||||
) # type: ignore[no-redef]
|
||||
|
||||
# Log information about the DDP and ZeRO bucketing
|
||||
if dist.get_debug_level() != dist.DebugLevel.OFF:
|
||||
@ -1531,7 +1534,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
else:
|
||||
# NOTE: Passing `param_groups` into the local optimizer constructor
|
||||
# bypasses the empty parameter list check
|
||||
self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef]
|
||||
self.optim: Optimizer = self._optim_constructor(
|
||||
param_groups, **self._optim_defaults
|
||||
) # type: ignore[no-redef]
|
||||
|
||||
# TODO: Manually add `self.param_groups` if using a functional
|
||||
# optimizer; remove this if/when the functional optimizers support
|
||||
|
@ -123,12 +123,11 @@ def _insert_stage_symbolic_backward(
|
||||
# getitem calls. If we have a target other than getitem in this
|
||||
# (forward-only) code, there is a bug.
|
||||
assert node.target == operator.getitem, (
|
||||
"Found non-getitem call in forward pass. "
|
||||
"Please report a bug to PiPPy"
|
||||
"Found non-getitem call in forward pass. Please report a bug to PiPPy"
|
||||
)
|
||||
assert len(node.args) == 2, (
|
||||
"Found malformed getitem call. Please report a bug to PiPPy"
|
||||
)
|
||||
assert (
|
||||
len(node.args) == 2
|
||||
), "Found malformed getitem call. Please report a bug to PiPPy"
|
||||
indexed_value, node_idx = tuple(node.args)
|
||||
|
||||
# indexed_value is a collection that we are indexing into. It could
|
||||
@ -249,8 +248,8 @@ class LossWrapper(torch.nn.Module):
|
||||
targets value into the loss function, and get and return the loss value, which will
|
||||
be backpropagated by PiPPy. The above class would then be instantiated like::
|
||||
|
||||
model = ... # instantiate the model
|
||||
loss_fn = torch.nn.MSELoss() # for the sake of demonstration
|
||||
model = ... # instantiate the model
|
||||
loss_fn = torch.nn.MSELoss() # for the sake of demonstration
|
||||
|
||||
wrapper = MyModelWrapper(model, loss_fn)
|
||||
pipe = Pipe.from_tracing(wrapper, ...)
|
||||
@ -818,9 +817,9 @@ class Pipe(torch.nn.Module):
|
||||
|
||||
# Get submodule
|
||||
callee = root.get_submodule(callee_name)
|
||||
assert not hasattr(
|
||||
callee, param_fqn
|
||||
), f"Module {callee_name} already has a parameter named {param_fqn}"
|
||||
assert not hasattr(callee, param_fqn), (
|
||||
f"Module {callee_name} already has a parameter named {param_fqn}"
|
||||
)
|
||||
|
||||
# Assign the parameter to the submodule
|
||||
if is_buffer:
|
||||
@ -979,7 +978,7 @@ class Pipe(torch.nn.Module):
|
||||
else:
|
||||
logger.debug("Pipeline is in inference mode, backward pass not generated")
|
||||
|
||||
logger.debug("Full pipe model:\n" f"{split}") # noqa: G004
|
||||
logger.debug(f"Full pipe model:\n{split}") # noqa: G004
|
||||
|
||||
return Pipe(
|
||||
split,
|
||||
@ -1184,7 +1183,7 @@ def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f"Specified target {qualname} referenced "
|
||||
f'nonexistent module {".".join(atoms[: i + 1])}'
|
||||
f"nonexistent module {'.'.join(atoms[: i + 1])}"
|
||||
) from e
|
||||
|
||||
mod_to_wrap = getattr(predecessor_module, atoms[-1])
|
||||
|
@ -306,17 +306,17 @@ def stage_backward(
|
||||
if isinstance(output_val, torch.Tensor):
|
||||
if not output_val.requires_grad and output_val.grad_fn is None:
|
||||
return
|
||||
assert isinstance(
|
||||
grad_val, (torch.Tensor, type(None))
|
||||
), f"Expected Tensor or None gradient but got {type(grad_val)}"
|
||||
assert isinstance(grad_val, (torch.Tensor, type(None))), (
|
||||
f"Expected Tensor or None gradient but got {type(grad_val)}"
|
||||
)
|
||||
stage_output_tensors.append(output_val)
|
||||
output_grad_tensors.append(grad_val)
|
||||
elif isinstance(output_val, (tuple, list)):
|
||||
if grad_val is None:
|
||||
return
|
||||
assert isinstance(
|
||||
grad_val, (tuple, list)
|
||||
), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
|
||||
assert isinstance(grad_val, (tuple, list)), (
|
||||
f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
|
||||
)
|
||||
assert len(output_val) == len(grad_val)
|
||||
for ov, gv in zip(output_val, grad_val):
|
||||
extract_tensors_with_grads(
|
||||
@ -350,7 +350,8 @@ def stage_backward(
|
||||
)
|
||||
|
||||
torch.autograd.backward(
|
||||
stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type]
|
||||
stage_output_tensors,
|
||||
grad_tensors=output_grad_tensors, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Extract gradients wrt the input values
|
||||
|
@ -140,9 +140,9 @@ def _shard_dict_of_args(
|
||||
real_num_chunks = num_chunks
|
||||
first_tensor = True
|
||||
|
||||
assert len(args_dict) == len(
|
||||
args_chunk_spec
|
||||
), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
|
||||
assert len(args_dict) == len(args_chunk_spec), (
|
||||
f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
|
||||
)
|
||||
|
||||
for arg_key, arg in args_dict.items():
|
||||
flat, spec = tree_flatten(arg)
|
||||
|
@ -706,7 +706,9 @@ class Schedule1F1B(PipelineScheduleSingle):
|
||||
recv_work.wait()
|
||||
|
||||
# Compute
|
||||
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
|
||||
output = self._stage.forward_one_chunk(
|
||||
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
|
||||
) # type: ignore[index]
|
||||
|
||||
# Clear previous chunk's forward sends (hopefully they have well
|
||||
# finished, otherwise, we are heavily communication bound, in which
|
||||
@ -762,7 +764,9 @@ class Schedule1F1B(PipelineScheduleSingle):
|
||||
fuse_work.wait()
|
||||
|
||||
# Now do the fwd
|
||||
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
|
||||
output = self._stage.forward_one_chunk(
|
||||
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
|
||||
) # type: ignore[index]
|
||||
|
||||
# Compute loss
|
||||
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
|
||||
@ -992,9 +996,9 @@ def _add_send_recv(
|
||||
progress = False
|
||||
# go in order of ranks even if dict keys aren't ordered
|
||||
for rank in sorted(compute_actions):
|
||||
assert (
|
||||
len(compute_actions[rank]) > 0
|
||||
), f"{rank=}, {len(compute_actions[rank])=}"
|
||||
assert len(compute_actions[rank]) > 0, (
|
||||
f"{rank=}, {len(compute_actions[rank])=}"
|
||||
)
|
||||
action = compute_actions[rank][0]
|
||||
|
||||
if not _ready_to_schedule(action, prev_actions[rank]):
|
||||
@ -1026,9 +1030,9 @@ def _validate_schedule(
|
||||
num_stages: int,
|
||||
num_microbatches: int,
|
||||
) -> dict[int, int]:
|
||||
assert (
|
||||
len(actions) == pp_group_size
|
||||
), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
|
||||
assert len(actions) == pp_group_size, (
|
||||
f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
|
||||
)
|
||||
for rank in range(pp_group_size):
|
||||
assert rank in actions, f"Schedule is missing actions for rank {rank}"
|
||||
|
||||
@ -1048,36 +1052,36 @@ def _validate_schedule(
|
||||
for action in actions[rank]:
|
||||
if action is None:
|
||||
continue
|
||||
assert isinstance(
|
||||
action, _Action
|
||||
), f"Got an invalid action: {action}, expected instance of _Action"
|
||||
assert isinstance(action, _Action), (
|
||||
f"Got an invalid action: {action}, expected instance of _Action"
|
||||
)
|
||||
s_id = action.stage_index
|
||||
ctype = action.computation_type
|
||||
mb_id = action.microbatch_index
|
||||
if ctype == F:
|
||||
stage_actions[s_id][F].add(mb_id)
|
||||
elif ctype == B:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][F]
|
||||
), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||
assert mb_id in stage_actions[s_id][F], (
|
||||
f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||
)
|
||||
stage_actions[s_id][B].add(mb_id)
|
||||
elif ctype == I:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][F]
|
||||
), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||
assert mb_id in stage_actions[s_id][F], (
|
||||
f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||
)
|
||||
stage_actions[s_id][I].add(mb_id)
|
||||
elif ctype == W:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][I]
|
||||
), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input"
|
||||
assert mb_id in stage_actions[s_id][I], (
|
||||
f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input"
|
||||
)
|
||||
stage_actions[s_id][W].add(mb_id)
|
||||
if s_id not in stage_index_to_rank_mapping:
|
||||
stage_index_to_rank_mapping[s_id] = rank
|
||||
else:
|
||||
existing_rank = stage_index_to_rank_mapping[s_id]
|
||||
assert (
|
||||
rank == existing_rank
|
||||
), f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
|
||||
assert rank == existing_rank, (
|
||||
f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
|
||||
)
|
||||
|
||||
for s_id in stage_actions:
|
||||
f_mb = len(stage_actions[s_id][F])
|
||||
@ -1085,14 +1089,14 @@ def _validate_schedule(
|
||||
i_mb = len(stage_actions[s_id][I])
|
||||
w_mb = len(stage_actions[s_id][W])
|
||||
|
||||
assert (
|
||||
f_mb == num_microbatches
|
||||
), f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
|
||||
assert f_mb == num_microbatches, (
|
||||
f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
|
||||
)
|
||||
|
||||
assert (
|
||||
b_mb + (i_mb + w_mb) // 2 == num_microbatches
|
||||
), f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
|
||||
assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, (
|
||||
f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
|
||||
but got B={b_mb}, I={i_mb}, W={w_mb}"
|
||||
)
|
||||
return stage_index_to_rank_mapping
|
||||
|
||||
|
||||
@ -1289,9 +1293,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
||||
computation_type = action.computation_type
|
||||
mb_index = action.microbatch_index
|
||||
stage_index = action.stage_index
|
||||
assert (
|
||||
mb_index is not None
|
||||
), "All currently supported action types require valid microbatch_index"
|
||||
assert mb_index is not None, (
|
||||
"All currently supported action types require valid microbatch_index"
|
||||
)
|
||||
if computation_type == _ComputationType.FORWARD:
|
||||
# perform forward computation
|
||||
stage = stage_index_to_stage[stage_index]
|
||||
@ -1362,9 +1366,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
||||
computation_type = prev_rank_action.computation_type
|
||||
mb_index = prev_rank_action.microbatch_index
|
||||
stage_index = prev_rank_action.stage_index
|
||||
assert (
|
||||
mb_index is not None
|
||||
), "All currently supported action types require valid microbatch_index"
|
||||
assert mb_index is not None, (
|
||||
"All currently supported action types require valid microbatch_index"
|
||||
)
|
||||
# Only handle sends for the forward from a previous rank
|
||||
if computation_type == _ComputationType.FORWARD:
|
||||
# If not the last stage, then receive fwd activations
|
||||
@ -1393,9 +1397,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
||||
computation_type = next_rank_action.computation_type
|
||||
mb_index = next_rank_action.microbatch_index
|
||||
stage_index = next_rank_action.stage_index
|
||||
assert (
|
||||
mb_index is not None
|
||||
), "All currently supported action types require valid microbatch_index"
|
||||
assert mb_index is not None, (
|
||||
"All currently supported action types require valid microbatch_index"
|
||||
)
|
||||
# Only handle receives for the backwards from a next rank
|
||||
if computation_type in (FORWARD, BACKWARD_WEIGHT):
|
||||
# Next rank doing forward or weight update has no influence for the current rank backward recv
|
||||
@ -1503,9 +1507,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
||||
"""Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
|
||||
# TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
|
||||
# that it does not exist if it was created from a compute_comms schedule.
|
||||
assert (
|
||||
self.pipeline_order_with_comms is not None
|
||||
), "Must initialize compute_comms schedule before dump_csv"
|
||||
assert self.pipeline_order_with_comms is not None, (
|
||||
"Must initialize compute_comms schedule before dump_csv"
|
||||
)
|
||||
with open(filename, "w", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
for rank in self.pipeline_order_with_comms:
|
||||
@ -1541,9 +1545,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
||||
stage.stage_index: stage for stage in self._stages
|
||||
}
|
||||
|
||||
assert (
|
||||
self.pipeline_order_with_comms is not None
|
||||
), "Must call _load_actions() before calling _step_microbatches()"
|
||||
assert self.pipeline_order_with_comms is not None, (
|
||||
"Must call _load_actions() before calling _step_microbatches()"
|
||||
)
|
||||
|
||||
# recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
|
||||
bwd_recv_ops: dict[tuple[int, int], Work] = {}
|
||||
@ -1562,9 +1566,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
||||
unshard_ops[stage_idx].wait()
|
||||
del unshard_ops[stage_idx]
|
||||
unsharded_stages.add(stage_idx)
|
||||
assert (
|
||||
stage_idx in unsharded_stages
|
||||
), f"Attempted to compute on sharded {stage_idx=}"
|
||||
assert stage_idx in unsharded_stages, (
|
||||
f"Attempted to compute on sharded {stage_idx=}"
|
||||
)
|
||||
|
||||
# count either full_backward or backward_weight together, to determine when to sync DP grads
|
||||
backward_counter: Counter[int] = Counter()
|
||||
@ -1606,7 +1610,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
||||
assert (
|
||||
stage_idx,
|
||||
mb_index,
|
||||
) not in fwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing forward"
|
||||
) not in fwd_recv_ops, (
|
||||
"Recv twice for {stage_idx=} {mb_index=} without executing forward"
|
||||
)
|
||||
fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
|
||||
stage.get_fwd_recv_ops(mb_index)
|
||||
)
|
||||
@ -1614,7 +1620,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
||||
assert (
|
||||
stage_idx,
|
||||
mb_index,
|
||||
) not in bwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing backward"
|
||||
) not in bwd_recv_ops, (
|
||||
"Recv twice for {stage_idx=} {mb_index=} without executing backward"
|
||||
)
|
||||
bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
|
||||
stage.get_bwd_recv_ops(mb_index)
|
||||
)
|
||||
@ -1627,12 +1635,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
||||
unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator]
|
||||
elif comp_type == RESHARD:
|
||||
if stage_uses_fsdp:
|
||||
assert (
|
||||
stage_idx in unsharded_stages
|
||||
), f"Resharding {stage_idx=} without unsharding"
|
||||
assert (
|
||||
stage_idx not in unshard_ops
|
||||
), f"Resharding {stage_idx=} before finishing unshard"
|
||||
assert stage_idx in unsharded_stages, (
|
||||
f"Resharding {stage_idx=} without unsharding"
|
||||
)
|
||||
assert stage_idx not in unshard_ops, (
|
||||
f"Resharding {stage_idx=} before finishing unshard"
|
||||
)
|
||||
stage.submod.reshard() # type: ignore[operator]
|
||||
elif comp_type == FORWARD:
|
||||
if stage_uses_fsdp:
|
||||
@ -1739,7 +1747,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
||||
)
|
||||
# TODO(whc) what is the best practice for printing a multiline log?
|
||||
# logger will split it into multiple log lines, but this makes it hard to read (too wide)
|
||||
print(_format_pipeline_order(self.pipeline_order_with_comms, error_step_number=time_step)) # type: ignore[arg-type]
|
||||
print(
|
||||
_format_pipeline_order(
|
||||
self.pipeline_order_with_comms, # type: ignore[arg-type]
|
||||
error_step_number=time_step,
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
|
||||
|
@ -243,16 +243,16 @@ class _PipelineStageBase(ABC):
|
||||
configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
|
||||
which could show up as hangs, silent corruption, or other errors.
|
||||
"""
|
||||
assert (
|
||||
self._outputs_meta is None
|
||||
), "Attempting to reconfigure output_meta, which is not supported"
|
||||
assert self._outputs_meta is None, (
|
||||
"Attempting to reconfigure output_meta, which is not supported"
|
||||
)
|
||||
self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment]
|
||||
|
||||
def get_outputs_meta(self) -> tuple[torch.Tensor, ...]:
|
"""Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
assert (
self._outputs_meta is not None
), "Attempted to get_outputs_meta() without configuring output meta"
assert self._outputs_meta is not None, (
"Attempted to get_outputs_meta() without configuring output meta"
)
return self._outputs_meta

def _create_grad_send_info(
@ -358,12 +358,12 @@ class _PipelineStageBase(ABC):
prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs)

for info, tensor in zip(recv_infos, prev_stage_outputs):
assert isinstance(
tensor, torch.Tensor
), f"expected tensor values as outputs from prev stage, got {type(tensor)}"
assert isinstance(
info, _RecvInfo
), "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo"
assert isinstance(tensor, torch.Tensor), (
f"expected tensor values as outputs from prev stage, got {type(tensor)}"
)
assert isinstance(info, _RecvInfo), (
"set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo"
)

# We don't need to do a data copy here, since we can directly pass the activation tensor reference from
# one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve
@ -376,9 +376,9 @@ class _PipelineStageBase(ABC):
"""
Returns the input grad tensors for this stage, which correspond to the stage inputs during forward.
"""
assert (
self.has_backward
), "can't steal_bwd_input if this stage doesn't have backward"
assert self.has_backward, (
"can't steal_bwd_input if this stage doesn't have backward"
)
assert not self.is_first, "can't get bwd output if this stage is first"

self._check_chunk_id(mb_index)
@ -391,22 +391,22 @@ class _PipelineStageBase(ABC):
Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv.
Does not detach or set '_requires_grad'.
"""
assert isinstance(
next_stage_bwd_outputs, tuple
), f"Expected tuple, got {type(next_stage_bwd_outputs)}"
assert isinstance(next_stage_bwd_outputs, tuple), (
f"Expected tuple, got {type(next_stage_bwd_outputs)}"
)

assert (
self.has_backward
), "can't set bwd input if this stage doesn't have backward"
assert self.has_backward, (
"can't set bwd input if this stage doesn't have backward"
)
assert not self.is_last, "can't set bwd input if this stage is last"
recv_infos = self.grad_recv_info[mb_index]
for info, tensor in zip(recv_infos, next_stage_bwd_outputs):
assert isinstance(
tensor, torch.Tensor
), f"expected tensor values as outputs from prev stage, got {type(tensor)}"
assert isinstance(
info, _RecvInfo
), f"Expected a recv info, got {type(info)}"
assert isinstance(tensor, torch.Tensor), (
f"expected tensor values as outputs from prev stage, got {type(tensor)}"
)
assert isinstance(info, _RecvInfo), (
f"Expected a recv info, got {type(info)}"
)
info.buffer = tensor

def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
@ -1053,9 +1053,9 @@ class _PipelineStage(_PipelineStageBase):
# If the input is a getitem, we need to go deeper
arg_node = arg_node.args[0]

assert (
arg_node.op == "call_module"
), f"Expecting call_module, got {arg_node.op}"
assert arg_node.op == "call_module", (
f"Expecting call_module, got {arg_node.op}"
)
src_stage = self.get_stage_index_of_submod(arg_node.name)

# Create a receive buffer for this placeholder
@ -1081,7 +1081,8 @@ class _PipelineStage(_PipelineStageBase):
args_recv_info: list[InputInfo] = []
# Filter out placeholder nodes from `self.submod` (a GraphModule)
placeholders = filter( # type: ignore[var-annotated]
lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr]
lambda node: node.op == "placeholder", # type: ignore[arg-type]
self.submod.graph.nodes, # type: ignore[arg-type,union-attr]
)
# `placeholders` are nodes internal to submod.
# `self.node.args` are dependency nodes in the outer graph.
@ -1300,9 +1301,9 @@ class PipelineStage(_PipelineStageBase):
raise RuntimeError(
"Failed to perform pipeline shape inference- are your inputs on the same device as your module?"
) from e
assert (
output_args is not None
), "If passing input_args, also pass output_args to override shape inference"
assert output_args is not None, (
"If passing input_args, also pass output_args to override shape inference"
)
self._configure_outputs_meta(
(output_args,) if isinstance(output_args, torch.Tensor) else output_args
)
@ -1346,9 +1347,9 @@ class PipelineStage(_PipelineStageBase):
)
args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args)
else:
assert (
len(args) == 0
), "Can't supply input args for shape inference on non-first stage"
assert len(args) == 0, (
"Can't supply input args for shape inference on non-first stage"
)
objects = [None]
logger.debug(
"Shape inference: stage %s receiving from stage %s",
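
The hunks above all apply the same mechanical rewrite: the old layout wraps the assert condition in parentheses and hangs the message off the closing paren, while the new layout keeps the condition on the assert line and parenthesizes the message instead. A minimal runnable sketch of the two spellings follows; the variable and message here are illustrative, not taken from the patch.

# Illustrative stand-in for the stage's configured output metadata.
outputs_meta = ("meta_tensor",)

# Old layout: condition wrapped, message trailing the closing paren.
assert (
    outputs_meta is not None
), "Attempted to get_outputs_meta() without configuring output meta"

# New layout: condition on the assert line, message wrapped instead.
assert outputs_meta is not None, (
    "Attempted to get_outputs_meta() without configuring output meta"
)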
|
@ -80,9 +80,9 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
world_size = world_size_opt
if rank != -1 or world_size != -1 or world_size_opt is None:
query_dict = _query_to_dict(result.query)
assert (
"rank" not in query_dict and "world_size" not in query_dict
), f"The url: {url} has node-specific arguments(rank, world_size) already."
assert "rank" not in query_dict and "world_size" not in query_dict, (
f"The url: {url} has node-specific arguments(rank, world_size) already."
)
if rank != -1:
query_dict["rank"] = str(rank)
if world_size != -1 or world_size_opt is None:
|
@ -137,13 +137,13 @@ def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
with _all_gather_dict_lock:
if not worker_names:
worker_names = _ALL_WORKER_NAMES
assert (
worker_name in worker_names
), f"{worker_name} is not expected by leader."
assert worker_name in worker_names, (
f"{worker_name} is not expected by leader."
)
states = _all_gather_sequence_id_to_states[sequence_id]
assert (
worker_name not in states.gathered_objects
), f"{worker_name} reported intent sequence id {sequence_id} twice. "
assert worker_name not in states.gathered_objects, (
f"{worker_name} reported intent sequence id {sequence_id} twice. "
)
states.gathered_objects[worker_name] = obj
if worker_names == set(states.gathered_objects.keys()):
states.proceed_signal.set()
@ -153,9 +153,9 @@ def _broadcast_to_followers(sequence_id, objects_map):
with _all_gather_dict_lock:
states = _all_gather_sequence_id_to_states[sequence_id]

assert (
not states.proceed_signal.is_set()
), f"Termination signal sequence id {sequence_id} got set twice."
assert not states.proceed_signal.is_set(), (
f"Termination signal sequence id {sequence_id} got set twice."
)
states.gathered_objects = objects_map
states.proceed_signal.set()

@ -202,9 +202,9 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
function blocks until all workers have received the gathered results.
"""
if not worker_names:
assert (
_ALL_WORKER_NAMES is not None
), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
assert _ALL_WORKER_NAMES is not None, (
"`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
)
worker_names = _ALL_WORKER_NAMES
leader_name = min(worker_names)

@ -930,8 +930,7 @@ def _get_should_profile():
ActiveProfilerType = torch._C._profiler.ActiveProfilerType
return (
torch.autograd._profiler_enabled()
and torch._C._autograd._profiler_type()
== ActiveProfilerType.LEGACY # type: ignore[attr-defined]
and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
)

|
@ -23,7 +23,7 @@ def _to_device(device: DeviceType) -> torch.device:


def _to_device_map(
device_map: dict[DeviceType, DeviceType]
device_map: dict[DeviceType, DeviceType],
) -> dict[torch.device, torch.device]:
full_device_map: dict[torch.device, torch.device] = {}
reverse_map: dict[torch.device, torch.device] = {}
@ -127,7 +127,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
>>> options = TensorPipeRpcBackendOptions(
>>> num_worker_threads=8,
>>> device_maps={"worker1": {0: 1}}
>>> # maps worker0's cuda:0 to worker1's cuda:1
>>> # maps worker0's cuda:0 to worker1's cuda:1
>>> )
>>> options.set_device_map("worker1", {1: 2})
>>> # maps worker0's cuda:1 to worker1's cuda:2
|
@ -63,10 +63,14 @@ class _server_process_global_profile(profile):
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> x, y = torch.tensor(1), torch.tensor(2)
>>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
>>> outer_profile_rref = rpc.remote(
... dst_worker_name, rpc._server_process_global_profile
... )
>>> outer_profile_rref.rpc_sync().__enter__()
>>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
>>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
>>> inner_profile_rref = rpc.remote(
... dst_worker_name, rpc._server_process_global_profile
... )
>>> inner_profile_rref.rpc_sync().__enter__()
>>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
>>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
|
@ -289,9 +289,9 @@ Important Notices

::

>>> # xdoctest: +SKIP("stub")
>>> import torch.distributed as dist
>>> dist.init_process_group(backend="gloo|nccl")
>>> # xdoctest: +SKIP("stub")
>>> import torch.distributed as dist
>>> dist.init_process_group(backend="gloo|nccl")

3. In your training program, you can either use regular distributed functions
or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
@ -302,9 +302,9 @@ Important Notices
::

local_rank = int(os.environ["LOCAL_RANK"])
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[local_rank],
output_device=local_rank)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[local_rank], output_device=local_rank
)

Please ensure that ``device_ids`` argument is set to be the only GPU device id
that your code will be operating on. This is generally the local rank of the
@ -331,17 +331,18 @@ utility

::

def main():
load_checkpoint(checkpoint_path)
initialize()
train()
def main():
load_checkpoint(checkpoint_path)
initialize()
train()

def train():
for batch in iter(dataset):
train_step(batch)

if should_checkpoint:
save_checkpoint(checkpoint_path)
def train():
for batch in iter(dataset):
train_step(batch)

if should_checkpoint:
save_checkpoint(checkpoint_path)

9. (Recommended) On worker errors, this tool will summarize the details of the error
(e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
@ -353,17 +354,19 @@ utility

::

from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.multiprocessing.errors import record

@record
def main():
# do train
pass

if __name__ == "__main__":
main()
@record
def main():
# do train
pass


if __name__ == "__main__":
main()
""" # noqa: E501

import os
import sys
import uuid
|
@ -297,9 +297,9 @@ class DTensor(torch.Tensor):

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
assert (
flatten_spec is not None
), "Expecting spec to be not None from `__tensor_flatten__` return value!"
assert flatten_spec is not None, (
"Expecting spec to be not None from `__tensor_flatten__` return value!"
)
local_tensor = inner_tensors["_local_tensor"]
spec, requires_grad = flatten_spec
unflatten_tensor_meta = TensorMeta(
@ -694,9 +694,7 @@ def distribute_tensor(
xla_distribute_tensor,
)

return xla_distribute_tensor(
tensor, device_mesh, placements
) # type:ignore[return-value]
return xla_distribute_tensor(tensor, device_mesh, placements) # type:ignore[return-value]
except ImportError as e:
msg = "To use DTensor API with xla, you must install the torch_xla package!"
raise ImportError(msg) from e
@ -930,7 +928,9 @@ def distribute_module(
FutureWarning,
stacklevel=2,
)
module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg]
module.register_forward_pre_hook(
lambda _, inputs: input_fn(inputs, device_mesh) # type: ignore[call-arg]
)
elif num_args == 3:
# input_fn takes in module, inputs, device mesh
module.register_forward_pre_hook(
@ -990,9 +990,9 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))

# check device_mesh againts placements
assert device_mesh.ndim == len(
placements
), "mesh dimension does not match the length of placements"
assert device_mesh.ndim == len(placements), (
"mesh dimension does not match the length of placements"
)

assert kwargs["layout"] == torch.strided, "layout value not supported!"
torch_stride = torch._prims_common.make_contiguous_strides_for(size)
|
@ -75,7 +75,8 @@ def found_inf_reduce_handler(
) -> None:
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
local_tensor_args = pytree.tree_unflatten(
cast(list[object], op_info.local_args), op_info.args_tree_spec # type: ignore[arg-type]
cast(list[object], op_info.local_args),
op_info.args_tree_spec, # type: ignore[arg-type]
)
local_tensor_args = cast(tuple[object, ...], local_tensor_args)
op_call(*local_tensor_args, **op_info.local_kwargs)
@ -200,8 +201,9 @@ class OpDispatcher:
# did not already construct one
random._rng_tracker = random.OffsetBasedRNGTracker(mesh)

first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
torch.Tensor, local_tensor_args[0]
first_arg, first_local_arg = (
cast(dtensor.DTensor, args[0]),
cast(torch.Tensor, local_tensor_args[0]),
)
rng_context = (
random._rng_tracker._distribute_region(first_arg._spec)
@ -422,18 +424,18 @@ class OpDispatcher:
def wrap(res: object, spec: OutputSpecType) -> object:
if isinstance(res, torch.Tensor):
if spec is not None:
assert isinstance(
spec, DTensorSpec
), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
assert isinstance(spec, DTensorSpec), (
f"output spec does not match with output! Expected DTensorSpec, got {spec}."
)
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
else:
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
assert res.ndim == 0, "output tensor should be scalar!"
return res
elif isinstance(res, (list, tuple)):
assert spec is not None and isinstance(
spec, (list, tuple)
), f"output spec does not match with output! Expected list/tuple, got {spec}."
assert spec is not None and isinstance(spec, (list, tuple)), (
f"output spec does not match with output! Expected list/tuple, got {spec}."
)
res_list = []
for e, s in zip(res, spec):
res_list.append(OpDispatcher.wrap(e, s))
|
@ -152,9 +152,9 @@ class OpStrategy(StrategyType):
if isinstance(output_spec, DTensorSpec):
return output_spec.mesh.shape
else:
assert isinstance(
output_spec, tuple
), "found no DTensorSpec in the OpStrategy!"
assert isinstance(output_spec, tuple), (
"found no DTensorSpec in the OpStrategy!"
)
assert output_spec[0] is not None
return output_spec[0].mesh.shape

|
@ -63,9 +63,9 @@ class EinsumDims:
if is_batch_dim:
batch_dims.append(dim_char)
else:
assert (
len(input_dims) == 2
), "free dimension only supported for two inputs!"
assert len(input_dims) == 2, (
"free dimension only supported for two inputs!"
)
lhs, rhs = input_dims
if dim_char in lhs:
lhs_out_only_dims.append(dim_char)
|
@ -89,9 +89,9 @@ class _MaskPartial(Partial):
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert (
self.offset_shape is not None
), "offset_shape needs to be set for _MaskPartial"
assert self.offset_shape is not None, (
"offset_shape needs to be set for _MaskPartial"
)
local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
self.offset_shape[self.offset_dim],
num_chunks,
|
@ -994,9 +994,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
)
output_specs_list.append(weight_out_spec if output_mask[1] else None)
else:
assert (
output_mask[1] is False
), "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
assert output_mask[1] is False, (
"output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
)
output_specs_list.append(None)

# arg: bias
@ -1020,9 +1020,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
)
output_specs_list.append(bias_out_spec if output_mask[2] else None)
else:
assert (
output_mask[2] is False
), "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
assert output_mask[2] is False, (
"output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
)
output_specs_list.append(None)

out_tuple_strategy.strategies.append(
|
@ -155,9 +155,9 @@ def _scaled_mm_like_strategy(
assert isinstance(scale_mat2_strategy, OpStrategy)
# TODO: add support for these later
assert bias_strategy is None, "_scaled_mm on DTensors doesn't support bias"
assert (
scale_result_strategy is None
), "_scaled_mm on DTensors doesn't support scale_result"
assert scale_result_strategy is None, (
"_scaled_mm on DTensors doesn't support scale_result"
)
# generate all possible strategies for mm
mm_strategy = gen_einsum_strategies(mm_equation, mesh)
# filter out invalid strategies and associate costs
|
@ -445,9 +445,9 @@ def pointwise_strategy(

followed_strategy = op_schema.args_schema[max_shards_strategy_index]

assert isinstance(
followed_strategy, OpStrategy
), f"no strategy to follow for {op_schema}!"
assert isinstance(followed_strategy, OpStrategy), (
f"no strategy to follow for {op_schema}!"
)
return common_pointwise_strategy(
mesh, op_schema.args_schema, followed_strategy, linearity
)
|
@ -254,9 +254,9 @@ def dim_movedim(

def dim_repeat(ndim: int, sizes: Shape) -> DimMap:
sizes = normalize_sizes(sizes)
assert (
len(sizes) >= ndim
), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
assert len(sizes) >= ndim, (
f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
)
pad = len(sizes) - ndim
return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple(
Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:])
@ -275,9 +275,9 @@ def infer_size(total_size: int, sizes: Shape) -> Shape:
if infers:
size = -size
missing_size = total_size // size
assert (
total_size % size == 0
), f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
assert total_size % size == 0, (
f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
)
return tuple(s if s != -1 else missing_size for s in sizes)
assert size == total_size, f"sizes do not match {total_size} vs {size}"
return sizes
@ -538,9 +538,9 @@ def propagate_shape_and_sharding(
for size, shard in zip(mesh_sizes, input_src_placements):
if isinstance(shard, Shard) and shard.dim == in_dim:
submesh_size *= size
assert (
out_size % submesh_size == 0
), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
assert out_size % submesh_size == 0, (
f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
)

# we will only shard our first component of the split
return in_dim if cmd.split_id == 0 else None
|
@ -45,7 +45,7 @@ def register_prop_rule(
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def wrapper(
impl: Callable[[OpSchema], OutputSharding]
impl: Callable[[OpSchema], OutputSharding],
) -> Callable[[OpSchema], OutputSharding]:
overloads = op if isinstance(op, list) else [op]
for overload in overloads:
@ -102,7 +102,7 @@ def register_op_strategy(


def as_list(
x: Union[list[object], object]
x: Union[list[object], object],
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
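
The two decorator hunks above show a second recurring change: when a lone parameter is long enough to be wrapped onto its own line, the new layout adds a trailing comma after it. A small self-contained sketch of the same shape, with made-up names:

from typing import Callable


def register_rule(
    impl: Callable[[int], int],  # trailing comma appears once the lone parameter wraps
) -> Callable[[int], int]:
    # Illustrative pass-through; a real registry would record `impl` somewhere.
    return impl


double = register_rule(lambda x: x * 2)
assert double(3) == 6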
|
@ -231,9 +231,9 @@ def redistribute_local_tensor(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert (
current.is_shard()
), f"Current placement should be shard but found {current}"
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
|
@ -487,9 +487,9 @@ class ShardingPropagator:

strategy_costs: list[float] = []
for strtg in strategy.strategies:
assert (
strtg.redistribute_cost is not None
), "must set redistribute cost each strategy!"
assert strtg.redistribute_cost is not None, (
"must set redistribute cost each strategy!"
)
redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
strategy_costs.append(redistribute_cost)

|
@ -73,9 +73,9 @@ def compute_local_shape_and_global_offset(
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(
local_shape
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
)
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
@ -141,16 +141,15 @@ def compute_local_shape_and_global_offset(

if isinstance(placement, _StridedShard):
strided_part_seen[shard_dim] = True
shard_idx_stride_by_mesh_dim[shard_dim][
idx
] = num_shards_by_tensor_dim[shard_dim] // (
placement.split_factor * mesh_dim_size
shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
num_shards_by_tensor_dim[shard_dim]
// (placement.split_factor * mesh_dim_size)
)
else:
num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
shard_idx_stride_by_mesh_dim[shard_dim][
idx
] = num_shards_by_tensor_dim[shard_dim]
shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
num_shards_by_tensor_dim[shard_dim]
)

shard_idx = [
sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
@ -205,9 +204,9 @@ def compute_global_tensor_info(
)
shard_dim = shard_placement.dim

assert (
shard_dim < tensor.ndim
), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
assert shard_dim < tensor.ndim, (
f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
)

local_dim_size = tensor_shape[shard_dim]
tensor_shape[shard_dim] = local_dim_size * mesh_dim_size
|
@ -283,9 +283,9 @@ class CommDebugMode(TorchDispatchMode):
"module_type" in self.advanced_module_tracker.module_helper_dict[fqn]
and include_module_data
):
json_dict[
"module_type"
] = self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]
json_dict["module_type"] = (
self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]
)

if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
for (
@ -659,9 +659,9 @@ class CommDebugMode(TorchDispatchMode):
operation_dict["is_bw"] = self.advanced_module_tracker.is_bw

# tracks if the operation is part of activation checkpointing
operation_dict[
"is_activation_checkpointing"
] = self.advanced_module_tracker.activation_checkpointing
operation_dict["is_activation_checkpointing"] = (
self.advanced_module_tracker.activation_checkpointing
)

if any(t == DTensor for t in types):
for ele in args:
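
The CommDebugMode hunks above illustrate a third pattern: long assignments to a subscripted target used to be broken inside the square brackets, whereas the new layout keeps the subscript on one line and wraps the right-hand side in parentheses. A standalone sketch with placeholder data:

# Placeholder data standing in for the module tracker's bookkeeping.
module_helper_dict = {"model.layer1": {"module_type": "Linear"}}
fqn = "model.layer1"
json_dict: dict[str, str] = {}

# Old layout: the subscript itself is split across lines.
json_dict[
    "module_type"
] = module_helper_dict[fqn]["module_type"]

# New layout: the subscript stays intact and the right-hand side is parenthesized.
json_dict["module_type"] = (
    module_helper_dict[fqn]["module_type"]
)

assert json_dict["module_type"] == "Linear"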
|
@ -108,9 +108,9 @@ def _compute_local_shape_and_global_offset(
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(
local_shape
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
)
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
|
@ -2,6 +2,7 @@
To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.py -e MLP_operation_tracing
"""

import argparse
import os
from typing import Callable, Union
|
@ -6,6 +6,7 @@ with intermediate activations sharded across mutliple GPUs via DTensor
To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
"""

import os
import time

|
@ -3,6 +3,7 @@
The following example demonstrates how to represent torchrec's embedding
sharding with the DTensor API.
"""

import argparse
import os
from functools import cached_property
|
@ -253,22 +253,18 @@ class _AttentionOp(Protocol):
key: torch.Tensor,
value: torch.Tensor,
**kwargs: object,
) -> tuple[torch.Tensor, ...]:
...
) -> tuple[torch.Tensor, ...]: ...


class _RingRotater(ABC):
@abstractmethod
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None:
...
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ...

@abstractmethod
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None:
...
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ...

@abstractmethod
def next_buffer(self) -> torch.Tensor:
...
def next_buffer(self) -> torch.Tensor: ...


class _AllToAllRotater(_RingRotater):
@ -1097,15 +1093,13 @@ class _LoadBalancer(ABC):
@abstractmethod
def shard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
...
) -> torch.Tensor: ...

@classmethod
@abstractmethod
def unshard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
...
) -> torch.Tensor: ...


class _SequentialSharder(_LoadBalancer):
@ -1147,9 +1141,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
def shard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
assert (
cls.ROUND_ROBIN_CYCLE == 2
), "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
assert cls.ROUND_ROBIN_CYCLE == 2, (
"The current implementation only works if ROUND_ROBIN_CYCLE is 2."
)
cp_world_size = mesh.size()
cp_rank = mesh.get_local_rank()
assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0
@ -1163,9 +1157,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
def unshard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
assert (
cls.ROUND_ROBIN_CYCLE == 2
), "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
assert cls.ROUND_ROBIN_CYCLE == 2, (
"The current implementation only works if ROUND_ROBIN_CYCLE is 2."
)
buffer = buffer.contiguous()
cp_world_size = mesh.size()

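
The context-parallel hunks above also fold the `...` body of Protocol and abstract-method stubs onto the signature line. A minimal sketch of the before/after shape, using throwaway classes rather than anything from the patch:

from abc import ABC, abstractmethod


class RotaterOldStyle(ABC):
    @abstractmethod
    def next_buffer(self) -> list[int]:
        ...


class RotaterNewStyle(ABC):
    @abstractmethod
    def next_buffer(self) -> list[int]: ...


class ConcreteRotater(RotaterNewStyle):
    def next_buffer(self) -> list[int]:
        return [0, 1, 2]


assert ConcreteRotater().next_buffer() == [0, 1, 2]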
|
@ -113,9 +113,15 @@ def local_map(
>>> device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
>>> W_dt = distribute_tensor(
... W, device_mesh, (col_wise)
... ) # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(
... X, device_mesh, (row_wise)
... ) # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(
... device_mesh, W_dt, X_dt
... ) # apply local_mm_allreduce_forward to DTensors

.. note:: This API is currently experimental and subject to change
"""
@ -151,9 +157,9 @@ def local_map(
)
if in_placements is not None:
spec = in_placements[idx]
assert (
spec is not None
), f"DTensor input {arg} expects placements but received {spec}!"
assert spec is not None, (
f"DTensor input {arg} expects placements but received {spec}!"
)

if not isinstance(spec, tuple):
spec = tuple(spec)
@ -208,17 +214,17 @@ def local_map(
)
for out, spec in zip(flat_out, out_placements_tuple):
if isinstance(out, torch.Tensor):
assert not isinstance(
out, DTensor
), f"torch.Tensor output expected but received {type(out)}: {out}"
assert not isinstance(out, DTensor), (
f"torch.Tensor output expected but received {type(out)}: {out}"
)

flat_dist_out.append(
DTensor.from_local(out, device_mesh, spec, run_check=False)
)
else:
assert (
spec is None
), f"Non-tensor output {out} expects None placements but received {spec}!"
assert spec is None, (
f"Non-tensor output {out} expects None placements but received {spec}!"
)

flat_dist_out.append(out)

|
@ -188,9 +188,14 @@ def _mark_sharding(
"""
Mark the sharding strategy for each node in the graph module.
"""
placement_strategies: dict[
Node, PlacementStrategy
] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)
placement_strategies: dict[Node, PlacementStrategy] = (
_mark_tensor_parallel_shardings(
gm,
graph_signature,
mesh,
parameter_placements,
)
)

for node in gm.graph.nodes:
if node.op == "placeholder":
@ -202,9 +207,9 @@ def _mark_sharding(
elif node.op == "call_function":
if node.target == operator.getitem:
input_nodes = node.all_input_nodes
assert (
len(input_nodes) == 1
), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
assert len(input_nodes) == 1, (
f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
)
arg_strategy = placement_strategies[input_nodes[0]]
placement_strategies[node] = _create_placement_strategy(
node,
|
@ -328,7 +328,9 @@ class DTensorExtensions(FSDPExtensions):
self.device_handle = device_handle
# we have to use the dynamo disable this way to disable dynamo as the decorater way would
# trigger build failure with torch deploy...
self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform) # type: ignore[method-assign]
self.post_unflatten_transform = torch._dynamo.disable( # type: ignore[method-assign]
self.post_unflatten_transform
)

def pre_flatten_transform(
self,
|
@ -64,9 +64,7 @@ def input_reshard(
return module


def _pack_hook_tp(
mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor
) -> Any: # noqa: D401
def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401
"""Hook function called after FWD to shard input."""
if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
@ -84,9 +82,7 @@ def _pack_hook_tp(
return x


def _unpack_hook_tp(
mesh: DeviceMesh, input_reshard_dim: int, x: Any
) -> torch.Tensor: # noqa: D401
def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401
"""Hook function called before activation recomputing in BWD to restore input."""
if (
isinstance(x, DTensor)
|
@ -38,8 +38,7 @@ class ParallelStyle(ABC):
src_data_rank: Optional[int] = 0

@abstractmethod
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
...
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ...


class ColwiseParallel(ParallelStyle):
@ -467,19 +466,21 @@ class PrepareModuleInput(ParallelStyle):
)
self.use_local_output = use_local_output
if self.input_layouts is not None:
assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
assert len(self.input_layouts) == len(
self.desired_input_layouts
), "input_layouts and desired_input_layouts should have same length!"
assert self.desired_input_layouts is not None, (
"desired module inputs should not be None!"
)
assert len(self.input_layouts) == len(self.desired_input_layouts), (
"input_layouts and desired_input_layouts should have same length!"
)
self.with_kwargs = input_kwarg_layouts is not None
self.input_kwarg_layouts = input_kwarg_layouts or {}
self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
if self.with_kwargs:
assert len(self.input_kwarg_layouts) == len(
self.desired_input_kwarg_layouts
), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
), (
"input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
)

def _prepare_input_arg(
self,
@ -494,9 +495,9 @@ class PrepareModuleInput(ParallelStyle):
# assert inp.placements[0] == input_layout
dt_inp = input
else:
assert isinstance(
input, torch.Tensor
), "expecting input to be a torch.Tensor!"
assert isinstance(input, torch.Tensor), (
"expecting input to be a torch.Tensor!"
)
dt_inp = DTensor.from_local(
input, mesh, (input_layout,), run_check=False
)
@ -517,9 +518,9 @@ class PrepareModuleInput(ParallelStyle):
if len(inputs) != len(self.input_layouts):
raise ValueError("module inputs and input_layouts should have same length!")

assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
assert self.desired_input_layouts is not None, (
"desired module inputs should not be None!"
)
for inp, input_layout, desired_layout in zip(
inputs, self.input_layouts, self.desired_input_layouts
):
@ -551,7 +552,9 @@ class PrepareModuleInput(ParallelStyle):
with_kwargs=True,
) # type: ignore[misc]
else:
module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg]
module.register_forward_pre_hook(
lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
) # type: ignore[misc, call-arg]
return module


@ -611,9 +614,9 @@ class PrepareModuleOutput(ParallelStyle):
else desired_output_layouts
)
self.use_local_output = use_local_output
assert len(self.output_layouts) == len(
self.desired_output_layouts
), "output_layouts and desired_output_layouts should have same length!"
assert len(self.output_layouts) == len(self.desired_output_layouts), (
"output_layouts and desired_output_layouts should have same length!"
)

def _prepare_out_fn(self, outputs, device_mesh):
prepared_outputs = []
@ -649,5 +652,7 @@ class PrepareModuleOutput(ParallelStyle):
return tuple(prepared_outputs)

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)) # type: ignore[misc, call-arg]
module.register_forward_hook(
lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
) # type: ignore[misc, call-arg]
return module
|
@ -83,9 +83,9 @@ class Shard(Placement):
few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
This is because collectives usually require equal size tensor inputs
"""
assert (
self.dim <= tensor.ndim
), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
assert self.dim <= tensor.ndim, (
f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
)

# chunk tensor over dimension `dim` into n slices
tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
@ -468,9 +468,9 @@ class _StridedShard(Shard):
"""
TODO: currently _StridedShard does not support padding
"""
assert (
self.dim <= tensor.ndim
), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
assert self.dim <= tensor.ndim, (
f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
)

total_split = num_chunks * self.split_factor
assert tensor.size(self.dim) % total_split == 0, (
Some files were not shown because too many files have changed in this diff