[BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
Author: Xuehai Pan
Date: 2025-02-28 11:10:58 +08:00
Committed by: PyTorch MergeBot
parent 4e160d5fd9
commit 995df34b19
143 changed files with 920 additions and 774 deletions
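
This PR drops "torch/d*/**" from the Black filelist (first hunk below), so torch/distributed and torch/distributions are now formatted by ruff format; the rest of the diff is the resulting mechanical reformat. The most common pattern, visible in almost every hunk, is how long assert statements are wrapped: Black parenthesizes the condition and leaves the message trailing, while ruff format keeps the condition inline and parenthesizes the message. A minimal, self-contained sketch (names and values are illustrative, not taken from any one file):

    group_size = 4
    dim_size = 16

    # Black style (before): the condition is wrapped in parentheses and the
    # message trails after the closing parenthesis.
    assert (
        dim_size % group_size == 0
    ), f"dimension {dim_size} must be a multiple of group_size {group_size}"

    # ruff format style (after): the condition stays on one line and the
    # message is parenthesized instead.
    assert dim_size % group_size == 0, (
        f"dimension {dim_size} must be a multiple of group_size {group_size}"
    )

Both spellings are semantically identical, which is why the reformat is safe to apply wholesale.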

View File

@ -59,7 +59,6 @@ USE_BLACK_FILELIST = re.compile(
# torch/[a-c]*/** # torch/[a-c]*/**
"torch/[a-c]*/**", "torch/[a-c]*/**",
# torch/d*/** # torch/d*/**
"torch/d*/**",
# torch/[e-n]*/** # torch/[e-n]*/**
"torch/[e-n]*/**", "torch/[e-n]*/**",
# torch/optim/** # torch/optim/**

View File

@ -36,11 +36,9 @@ _M = TypeVar("_M", nn.Module, list[nn.Module])
class _ContractFn(Protocol, Generic[_P, _T, _TState]): class _ContractFn(Protocol, Generic[_P, _T, _TState]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
...
def state(self, module: nn.Module) -> _TState: def state(self, module: nn.Module) -> _TState: ...
...
def contract( def contract(
@ -92,7 +90,7 @@ def contract(
# wraps will make functions decorated with contract() pickleable - needed for integration with torch.package # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
@wraps(state_cls) # type: ignore[arg-type] @wraps(state_cls) # type: ignore[arg-type]
def inner( def inner(
func: Callable[Concatenate[_M, _P], _M] func: Callable[Concatenate[_M, _P], _M],
) -> _ContractFn[Concatenate[_M, _P], _M, _TState]: ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
@wraps(func) @wraps(func)
def wrapper( def wrapper(
@ -232,9 +230,7 @@ def contract(
return module.__dict__.setdefault( # type: ignore[call-overload] return module.__dict__.setdefault( # type: ignore[call-overload]
STATE_KEY, STATE_KEY,
{}, # TODO(@yhcharles): this is a temporary fix, need a better way {}, # TODO(@yhcharles): this is a temporary fix, need a better way
).get( ).get(func) # type: ignore[call-overload]
func
) # type: ignore[call-overload]
wrapper.state = get_state # type: ignore[attr-defined] wrapper.state = get_state # type: ignore[attr-defined]
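
The Protocol hunk above also shows how ellipsis-only bodies are collapsed onto the signature line in this migration. A self-contained sketch with illustrative names:

    from typing import Protocol


    class _Formatter(Protocol):
        # before (Black kept the stub body on its own line):
        #     def format(self, text: str) -> str:
        #         ...
        # after (ruff format collapses the lone ellipsis onto the signature):
        def format(self, text: str) -> str: ...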

View File

@ -274,9 +274,9 @@ def reduce_scatter_tensor(
group_name = _resolve_group_name(group, tag) group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name) group_size = c10d._get_group_size_by_name(group_name)
assert ( assert self.size(scatter_dim) % group_size == 0, (
self.size(scatter_dim) % group_size == 0 f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" )
if scatter_dim != 0: if scatter_dim != 0:
tensor_list = torch.chunk(self, group_size, dim=scatter_dim) tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
self = torch.cat(tensor_list) self = torch.cat(tensor_list)
@ -313,9 +313,9 @@ def reduce_scatter_tensor_autograd(
group_name = _resolve_group_name(group, tag) group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name) group_size = c10d._get_group_size_by_name(group_name)
assert ( assert self.size(scatter_dim) % group_size == 0, (
self.size(scatter_dim) % group_size == 0 f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" )
if scatter_dim != 0: if scatter_dim != 0:
tensor_list = torch.chunk(self, group_size, dim=scatter_dim) tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
self = torch.cat(tensor_list) self = torch.cat(tensor_list)
@ -414,9 +414,9 @@ def reduce_scatter_tensor_coalesced(
assert len(scatter_dim) == len(inputs) assert len(scatter_dim) == len(inputs)
for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
assert ( assert tensor.size(dim) % group_size == 0, (
tensor.size(dim) % group_size == 0 f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" )
if dim != 0: if dim != 0:
tensor_list = torch.chunk(tensor, group_size, dim=dim) tensor_list = torch.chunk(tensor, group_size, dim=dim)
inputs[idx] = torch.cat(tensor_list) inputs[idx] = torch.cat(tensor_list)
@ -574,6 +574,7 @@ class AsyncCollectiveTensor(torch.Tensor):
tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size) tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
return _maybe_wrap_tensor(tensor) return _maybe_wrap_tensor(tensor)
""" """
elem: torch.Tensor elem: torch.Tensor
completed: bool completed: bool
@ -726,9 +727,9 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int
group_size = len(rankset) group_size = len(rankset)
tag = tag or c10d._get_group_tag(group) tag = tag or c10d._get_group_tag(group)
elif isinstance(group, DeviceMesh): elif isinstance(group, DeviceMesh):
assert ( assert group.ndim == 1, (
group.ndim == 1 "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" )
# TODO: it should run collective in the whole mesh instead of dim 0 # TODO: it should run collective in the whole mesh instead of dim 0
tag, rankset, _ = group._dim_group_infos[0] tag, rankset, _ = group._dim_group_infos[0]
group_size = len(rankset) group_size = len(rankset)
@ -763,9 +764,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
elif isinstance(group, str): elif isinstance(group, str):
return group return group
elif isinstance(group, DeviceMesh): elif isinstance(group, DeviceMesh):
assert ( assert group.ndim == 1, (
group.ndim == 1 "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" )
return group._dim_group_infos[0][2] return group._dim_group_infos[0][2]
elif isinstance(group, tuple): elif isinstance(group, tuple):
if ( if (
@ -837,11 +838,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True) req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
return y return y
@torch.compile(fullgraph=True) @torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y): def all_reduce_wait_compiled(y):
torch.ops.c10d_functional.wait_tensor(y) torch.ops.c10d_functional.wait_tensor(y)
return y * y return y * y
x = torch.ones(1280, 1280, device="cuda") + self.rank x = torch.ones(1280, 1280, device="cuda") + self.rank
# the context manager ensures that `wait_tensor(y)` will wait on the correct work object # the context manager ensures that `wait_tensor(y)` will wait on the correct work object
with allow_inflight_collective_as_graph_input_ctx(): with allow_inflight_collective_as_graph_input_ctx():
@ -1057,9 +1060,9 @@ def all_gather_tensor_inplace(
tag: str = "", tag: str = "",
gather_dim: int = 0, gather_dim: int = 0,
): ):
assert ( assert not async_op, (
not async_op "Can't remap async version of inplace op to functional collective"
), "Can't remap async version of inplace op to functional collective" )
group = group or dist.group.WORLD group = group or dist.group.WORLD
assert group is not None assert group is not None
@ -1076,9 +1079,9 @@ def reduce_scatter_tensor_inplace(
scatter_dim: int = 0, scatter_dim: int = 0,
tag: str = "", tag: str = "",
): ):
assert ( assert not async_op, (
not async_op "Can't remap async version of inplace op to functional collective"
), "Can't remap async version of inplace op to functional collective" )
group = group or dist.group.WORLD group = group or dist.group.WORLD
assert group is not None assert group is not None
@ -1105,9 +1108,9 @@ def all_reduce_inplace(
async_op: bool = False, async_op: bool = False,
tag: str = "", tag: str = "",
): ):
assert ( assert not async_op, (
not async_op "Can't remap async version of inplace op to functional collective"
), "Can't remap async version of inplace op to functional collective" )
group = group or dist.group.WORLD group = group or dist.group.WORLD
assert group is not None assert group is not None
@ -1124,9 +1127,9 @@ def all_to_all_inplace(
async_op=False, async_op=False,
tag: str = "", tag: str = "",
): ):
assert ( assert not async_op, (
not async_op "Can't remap async version of inplace op to functional collective"
), "Can't remap async version of inplace op to functional collective" )
group = group or dist.group.WORLD group = group or dist.group.WORLD
assert group is not None assert group is not None
@ -1149,12 +1152,12 @@ def all_gather_inplace(
async_op=False, async_op=False,
tag: str = "", tag: str = "",
): ):
assert ( assert not async_op, (
not async_op "Can't remap async version of inplace op to functional collective"
), "Can't remap async version of inplace op to functional collective" )
assert all( assert all(t.size(0) == tensor.size(0) for t in tensor_list), (
t.size(0) == tensor.size(0) for t in tensor_list "Remapping variable size all_gather is not yet supported"
), "Remapping variable size all_gather is not yet supported" )
group = group or dist.group.WORLD group = group or dist.group.WORLD
assert group is not None assert group is not None

View File

@ -592,7 +592,9 @@ class ShardedTensor(ShardedTensorBase):
assert ( assert (
isinstance(device, torch.device) isinstance(device, torch.device)
and device.index == torch.cuda.current_device() and device.index == torch.cuda.current_device()
), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" ), (
"""Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""
)
current_device = torch.device(torch.cuda.current_device()) current_device = torch.device(torch.cuda.current_device())
# returns a copy of ShardedTensor on CUDA current device # returns a copy of ShardedTensor on CUDA current device
@ -831,7 +833,9 @@ class ShardedTensor(ShardedTensorBase):
"rank:1/cuda:1", "rank:1/cuda:1",
], ],
) )
>>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4]) >>> st = ShardedTensor._init_from_local_tensor(
... local_tensor, sharding_spec, [2, 4]
... )
>>> st >>> st
ShardedTensor( ShardedTensor(
ShardedTensorMetadata( ShardedTensorMetadata(

View File

@ -219,9 +219,7 @@ def reshard_local_shard(
output_tensor_size = list(st_size) output_tensor_size = list(st_size)
output_tensor_size[current_sharding_dim] = sharded_dim_size output_tensor_size[current_sharding_dim] = sharded_dim_size
output_tensor_size[reshard_dim] = input_split_sizes[current_rank] output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
output_tensor_list[ output_tensor_list[placement.rank()] = torch.empty( # type: ignore[union-attr, index]
placement.rank()
] = torch.empty( # type: ignore[union-attr, index]
output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
) )
indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type] indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type]

View File

@ -16,6 +16,6 @@ with warnings.catch_warnings():
stacklevel=2, stacklevel=2,
) )
sys.modules[ sys.modules["torch.distributed._sharded_tensor"] = (
"torch.distributed._sharded_tensor" torch.distributed._shard.sharded_tensor
] = torch.distributed._shard.sharded_tensor )
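
The hunk above illustrates how over-long assignments are re-wrapped: instead of splitting the subscripted target across lines, the target stays intact and the right-hand side is parenthesized. A sketch with an illustrative stand-in module (not the real alias created in the diff):

    import sys
    import types

    _compat_module = types.ModuleType("compat_sketch")  # illustrative stand-in

    # before:
    #     sys.modules[
    #         "example.legacy_alias"
    #     ] = _compat_module
    # after (value parenthesized; the line here is short, so the parentheses
    # only demonstrate the shape the formatter produces for long lines):
    sys.modules["example.legacy_alias"] = (
        _compat_module
    )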

View File

@ -67,7 +67,7 @@ def _all_gather_sharded_tensor(
class CompanionMismatch(Exception): class CompanionMismatch(Exception):
... pass
def _iterate_state_dict( def _iterate_state_dict(
@ -409,9 +409,9 @@ def _create_cpu_state_dict(
def unpin_memory(t): def unpin_memory(t):
succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr())) succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))
assert ( assert succ == 0, (
succ == 0 f"Unpinning shared memory failed with error-code: {succ}"
), f"Unpinning shared memory failed with error-code: {succ}" )
weakref.finalize(t, unpin_memory, t) weakref.finalize(t, unpin_memory, t)
succ = int( succ = int(
@ -421,9 +421,9 @@ def _create_cpu_state_dict(
1, # lines up with 'cudaHostRegisterPortable' 1, # lines up with 'cudaHostRegisterPortable'
) )
) )
assert ( assert succ == 0, (
succ == 0 f"Pinning shared memory failed with error-code: {succ}"
), f"Pinning shared memory failed with error-code: {succ}" )
return t return t
elif pin_memory: elif pin_memory:
return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()

View File

@ -1525,8 +1525,7 @@ if TYPE_CHECKING:
@overload @overload
def empty( def empty(
*size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
) -> torch.Tensor: ) -> torch.Tensor: ...
...
@overload @overload
@ -1535,8 +1534,7 @@ def empty(
*, *,
dtype: Optional[_dtype] = None, dtype: Optional[_dtype] = None,
device: Optional[_device] = None, device: Optional[_device] = None,
) -> torch.Tensor: ) -> torch.Tensor: ...
...
def empty( # type: ignore[misc] def empty( # type: ignore[misc]

View File

@ -6,6 +6,7 @@ we keep the old import path starts with `_tensor` for
backward compatibility. We will remove this folder once backward compatibility. We will remove this folder once
we resolve all the BC issues. we resolve all the BC issues.
""" """
import sys import sys
from importlib import import_module from importlib import import_module

View File

@ -153,7 +153,7 @@ class FSDPMemTracker(MemTracker):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
fmt.display_snapshot("peak") fmt.display_snapshot("peak")
fmt.display_modulewise_snapshots(depth = 3, units = "MB") fmt.display_modulewise_snapshots(depth=3, units="MB")
""" """

View File

@ -379,7 +379,7 @@ class MemTracker(TorchDispatchMode):
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
mt.display_snapshot("peak") mt.display_snapshot("peak")
mt.display_modulewise_snapshots(depth = 3, units = "MiB") mt.display_modulewise_snapshots(depth=3, units="MiB")
Known Limitations: Known Limitations:
- The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``. - The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``.

View File

@ -42,6 +42,7 @@ class ModTracker:
def my_linear(m1, m2, bias): def my_linear(m1, m2, bias):
print(f"Current modules: {tracker.parents}") print(f"Current modules: {tracker.parents}")
return torch.mm(m1, m2.t()) + bias return torch.mm(m1, m2.t()) + bias
torch.nn.functional.linear = my_linear torch.nn.functional.linear = my_linear
mod(torch.rand(2, 2)) mod(torch.rand(2, 2))

View File

@ -255,9 +255,9 @@ class RuntimeEstimator(TorchDispatchMode):
Tuple[Any, float]: A tuple containing the result of the function and Tuple[Any, float]: A tuple containing the result of the function and
the mean operation time in milliseconds. the mean operation time in milliseconds.
""" """
assert isinstance( assert isinstance(cls.fake_mode, FakeTensorMode), (
cls.fake_mode, FakeTensorMode "Initialize/Assign FakeTensorMode before using this function"
), "Initialize/Assign FakeTensorMode before using this function" )
mean_op_time = 0.0 mean_op_time = 0.0
if func._overloadpacket not in _VIEW_OPS: if func._overloadpacket not in _VIEW_OPS:
try: try:
@ -289,9 +289,9 @@ class RuntimeEstimator(TorchDispatchMode):
Tuple[Any, float]: A tuple containing the result of the function and Tuple[Any, float]: A tuple containing the result of the function and
the mean operation time in milliseconds. the mean operation time in milliseconds.
""" """
assert ( assert torch.cuda.is_available(), (
torch.cuda.is_available() "Roofline estimation needs to access CUDA capabilities to make estimations"
), "Roofline estimation needs to access CUDA capabilities to make estimations" )
def get_num_bytes(t: torch.Tensor) -> int: def get_num_bytes(t: torch.Tensor) -> int:
""" """
@ -324,9 +324,9 @@ class RuntimeEstimator(TorchDispatchMode):
float: The estimated compute time in nanoseconds. float: The estimated compute time in nanoseconds.
""" """
if func_packet in flop_registry: if func_packet in flop_registry:
assert ( assert len(out_dtypes) == 1, (
len(out_dtypes) == 1 f"Only support single out dtype got {out_dtypes} for {func_packet}"
), f"Only support single out dtype got {out_dtypes} for {func_packet}" )
dtype = out_dtypes.pop() dtype = out_dtypes.pop()
# This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s
peak_gpu_flops = get_device_tflops(dtype) * 1e15 peak_gpu_flops = get_device_tflops(dtype) * 1e15
@ -487,9 +487,9 @@ class RuntimeEstimator(TorchDispatchMode):
def __enter__(self) -> Self: def __enter__(self) -> Self:
fake_mode = active_fake_mode() fake_mode = active_fake_mode()
assert isinstance( assert isinstance(fake_mode, FakeTensorMode), (
fake_mode, FakeTensorMode "No FakeTensorMode found, designed to used under FakeTensorMode"
), "No FakeTensorMode found, designed to used under FakeTensorMode" )
RuntimeEstimator.fake_mode = fake_mode RuntimeEstimator.fake_mode = fake_mode
self.total_runtime = 0.0 self.total_runtime = 0.0
self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0))

View File

@ -245,7 +245,7 @@ class SACEstimator(TorchDispatchMode):
with FakeTensorMode(): with FakeTensorMode():
module = ... module = ...
inp = ... inp = ...
with sac_estimator('operator-level-cost-model'): with sac_estimator("operator-level-cost-model"):
output = module(inp) output = module(inp)
sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True) sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True)
""" """
@ -442,9 +442,9 @@ class SACEstimator(TorchDispatchMode):
out_storages_cpu.update(_get_untyped_storages(o)) out_storages_cpu.update(_get_untyped_storages(o))
# Check if there's more than 1 CUDA device # Check if there's more than 1 CUDA device
assert ( assert len(cuda_devices) <= 1, (
len(cuda_devices) <= 1 f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
), f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}" )
# 2. Get the memory consumed by output # 2. Get the memory consumed by output
nbytes_cuda = sum( nbytes_cuda = sum(
@ -484,9 +484,9 @@ class SACEstimator(TorchDispatchMode):
if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): if acm_stats := self._sac_mod_metadata.get(mod_fqn, None):
acm_stats.sac_metadata.append(acm) acm_stats.sac_metadata.append(acm)
else: else:
assert ( assert mod_fqn == "Global", (
mod_fqn == "Global" f"Module {mod_fqn} not found in AC Mod Stats"
), f"Module {mod_fqn} not found in AC Mod Stats" )
self._sac_metadata.append(acm) self._sac_metadata.append(acm)
return out return out
@ -979,9 +979,9 @@ class SACEstimator(TorchDispatchMode):
def __enter__(self) -> Self: # type: ignore[no-untyped-def] def __enter__(self) -> Self: # type: ignore[no-untyped-def]
fake_mode = active_fake_mode() fake_mode = active_fake_mode()
assert isinstance( assert isinstance(fake_mode, FakeTensorMode), (
fake_mode, FakeTensorMode "SAC Estimator should be called in FakeTensorMode"
), "SAC Estimator should be called in FakeTensorMode" )
RuntimeEstimator.fake_mode = fake_mode RuntimeEstimator.fake_mode = fake_mode
self._mod_tracker.register_user_hooks( self._mod_tracker.register_user_hooks(
pre_fw_hook=self._pre_fw_hook, pre_fw_hook=self._pre_fw_hook,

View File

@ -38,9 +38,9 @@ def _perform_local_step(
""" """
overlap_info = zero._overlap_info overlap_info = zero._overlap_info
bucket_index = bucket.index() bucket_index = bucket.index()
assert ( assert len(zero.optim.param_groups) == 1, (
len(zero.optim.param_groups) == 1 "Overlapping DDP with ZeRO only supports a single parameter group"
), "Overlapping DDP with ZeRO only supports a single parameter group" )
# Construct the `gradients` input for the local optimizer step, which # Construct the `gradients` input for the local optimizer step, which
# expects `None` in a list position to indicate that the corresponding # expects `None` in a list position to indicate that the corresponding
@ -49,9 +49,9 @@ def _perform_local_step(
gradients: list[Optional[torch.Tensor]] = [ gradients: list[Optional[torch.Tensor]] = [
_NO_PARAM_UPDATE for _ in range(num_local_optim_params) _NO_PARAM_UPDATE for _ in range(num_local_optim_params)
] ]
assert ( assert bucket_index in overlap_info.offsets, (
bucket_index in overlap_info.offsets f"Bucket index {bucket_index} was not assigned to rank {rank}"
), f"Bucket index {bucket_index} was not assigned to rank {rank}" )
gradients_offset = overlap_info.offsets[bucket_index] gradients_offset = overlap_info.offsets[bucket_index]
bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index] bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
bucket_offset = bucket_assignment.offset bucket_offset = bucket_assignment.offset
@ -77,13 +77,13 @@ def _broadcast_bucket(
:class:`ZeroRedundancyOptimizer` instance. :class:`ZeroRedundancyOptimizer` instance.
""" """
overlap_info = zero._overlap_info overlap_info = zero._overlap_info
assert ( assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
len(overlap_info.assigned_ranks_per_bucket) > bucket_index "`assigned_ranks_per_bucket` is not fully constructed"
), "`assigned_ranks_per_bucket` is not fully constructed" )
# Sort to ensure the same ordering across ranks # Sort to ensure the same ordering across ranks
assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index]) assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
assert len(assigned_ranks) > 0, ( assert len(assigned_ranks) > 0, (
f"Bucket {bucket_index} should be " "assigned to at least one rank" f"Bucket {bucket_index} should be assigned to at least one rank"
) )
for assigned_rank in assigned_ranks: for assigned_rank in assigned_ranks:
bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank] bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
@ -273,9 +273,9 @@ def hook_with_zero_step(
rank = zero.global_rank rank = zero.global_rank
assert overlap_info.status == _OverlapStatus.INITIALIZED assert overlap_info.status == _OverlapStatus.INITIALIZED
assert ( assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
len(overlap_info.assigned_ranks_per_bucket) > bucket_index "`assigned_ranks_per_bucket` is not fully constructed"
), "`assigned_ranks_per_bucket` is not fully constructed" )
assigned_to_bucket = ( assigned_to_bucket = (
rank in overlap_info.assigned_ranks_per_bucket[bucket_index] rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
) )
@ -288,9 +288,9 @@ def hook_with_zero_step(
# Check that buckets are indexed incrementally starting from 0 in the # Check that buckets are indexed incrementally starting from 0 in the
# order of their autograd hooks firing # order of their autograd hooks firing
if len(overlap_info.bucket_indices_seen) > 0: if len(overlap_info.bucket_indices_seen) > 0:
assert ( assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, (
overlap_info.bucket_indices_seen[-1] == bucket_index - 1 "Bucket indices are not in incremental order"
), "Bucket indices are not in incremental order" )
else: else:
assert bucket_index == 0, "Bucket indices do not start from 0" assert bucket_index == 0, "Bucket indices do not start from 0"
overlap_info.bucket_indices_seen.append(bucket_index) overlap_info.bucket_indices_seen.append(bucket_index)

View File

@ -129,7 +129,7 @@ def bf16_compress_hook(
def fp16_compress_wrapper( def fp16_compress_wrapper(
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]] hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
""" """
Cast input tensor to ``torch.float16``, cast result of hook back to input dtype. Cast input tensor to ``torch.float16``, cast result of hook back to input dtype.
@ -167,7 +167,7 @@ def fp16_compress_wrapper(
def bf16_compress_wrapper( def bf16_compress_wrapper(
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]] hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
""" """
Warning: This API is experimental, and it requires NCCL version later than 2.9.6. Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
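
The comm-hook wrapper signatures above pick up another change: when a lone parameter already forces the signature onto multiple lines, a trailing comma is added after it. A self-contained sketch; in the real hunks the Callable annotation is long enough to force the wrap, and the short stand-in below just shows where the comma lands:

    from collections.abc import Callable


    def wrap_hook(
        hook: Callable[[int, float], float],  # trailing comma added here
    ) -> Callable[[int, float], float]:
        return hook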

View File

@ -223,8 +223,7 @@ class Join:
self._rank = dist.get_rank(self._process_group) self._rank = dist.get_rank(self._process_group)
self._device = device self._device = device
def __enter__(self): def __enter__(self): ...
...
def __exit__( def __exit__(
self, self,

View File

@ -52,7 +52,10 @@ def average_parameters(
def get_params_to_average( def get_params_to_average(
params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]] params: Union[
Iterable[torch.nn.Parameter],
Iterable[dict[str, torch.nn.Parameter]],
],
): ):
""" """
Return a list of parameters that need to average. Return a list of parameters that need to average.

View File

@ -550,9 +550,7 @@ def create_default_global_save_plan(
new_item = dataclasses.replace(item, index=new_index) new_item = dataclasses.replace(item, index=new_index)
new_items.append(new_item) new_items.append(new_item)
assert ( assert item.tensor_data.chunk is not None, f"""
item.tensor_data.chunk is not None
), f"""
Cannot create MD for tensor without bounds. Cannot create MD for tensor without bounds.
FQN: {item.index.fqn} FQN: {item.index.fqn}
""" """

View File

@ -414,41 +414,33 @@ class FileSystemBase(ABC):
@abstractmethod @abstractmethod
def create_stream( def create_stream(
self, path: Union[str, os.PathLike], mode: str self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]: ) -> Generator[io.IOBase, None, None]: ...
...
@abstractmethod @abstractmethod
def concat_path( def concat_path(
self, path: Union[str, os.PathLike], suffix: str self, path: Union[str, os.PathLike], suffix: str
) -> Union[str, os.PathLike]: ) -> Union[str, os.PathLike]: ...
...
@abstractmethod @abstractmethod
def rename( def rename(
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None: ) -> None: ...
...
@abstractmethod @abstractmethod
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: ...
...
@abstractmethod @abstractmethod
def mkdir(self, path: Union[str, os.PathLike]) -> None: def mkdir(self, path: Union[str, os.PathLike]) -> None: ...
...
@classmethod @classmethod
@abstractmethod @abstractmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: ...
...
@abstractmethod @abstractmethod
def exists(self, path: Union[str, os.PathLike]) -> bool: def exists(self, path: Union[str, os.PathLike]) -> bool: ...
...
@abstractmethod @abstractmethod
def rm_file(self, path: Union[str, os.PathLike]) -> None: def rm_file(self, path: Union[str, os.PathLike]) -> None: ...
...
class FileSystem(FileSystemBase): class FileSystem(FileSystemBase):
@ -512,7 +504,6 @@ class FileSystem(FileSystemBase):
class _FileSystemWriter(StorageWriter): class _FileSystemWriter(StorageWriter):
""" """
Basic implementation of StorageWriter using file IO. Basic implementation of StorageWriter using file IO.
@ -800,9 +791,9 @@ class FileSystemReader(StorageReader):
) )
target_tensor = planner.resolve_tensor(req).detach() target_tensor = planner.resolve_tensor(req).detach()
assert ( assert target_tensor.size() == tensor.size(), (
target_tensor.size() == tensor.size() f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" )
target_tensor.copy_(tensor) target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor) planner.commit_tensor(req, target_tensor)

View File

@ -135,12 +135,12 @@ def _get_state_dict_2d_layout(
for key, value in state_dict.items(): for key, value in state_dict.items():
specs[key] = (None, value.size()) specs[key] = (None, value.size())
if _is_nested_tensor(value): if _is_nested_tensor(value):
assert ( assert len(value.local_shards()) == 1, (
len(value.local_shards()) == 1 "Cannot handle ST with multiple shards"
), "Cannot handle ST with multiple shards" )
assert isinstance( assert isinstance(value, ShardedTensor), (
value, ShardedTensor "Can only handle nested ShardedTensor"
), "Can only handle nested ShardedTensor" )
shard = value.local_shards()[0] shard = value.local_shards()[0]
specs[key] = ( specs[key] = (
shard.metadata.shard_offsets, shard.metadata.shard_offsets,

View File

@ -151,7 +151,7 @@ class SavePlanner(abc.ABC):
>>> storage_meta: Optional[StorageMeta], >>> storage_meta: Optional[StorageMeta],
>>> is_coordinator: bool, >>> is_coordinator: bool,
>>> ) -> None: >>> ) -> None:
>>> # prefix all keys with `foo_`` >>> # prefix all keys with `foo_``
>>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator) >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted
@ -175,8 +175,8 @@ class SavePlanner(abc.ABC):
>>> from itertools import zip_longest >>> from itertools import zip_longest
>>> from dataclasses import replace >>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner): >>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>> # This sample doesn't handle ShardedTensors >>> # This sample doesn't handle ShardedTensors
>>> def create_global_plan(self, all_plans): >>> def create_global_plan(self, all_plans):
>>> iters = [iter(all_plans[0].items)] * len(all_plans) >>> iters = [iter(all_plans[0].items)] * len(all_plans)
>>> items_per_rank = [ >>> items_per_rank = [
@ -347,7 +347,7 @@ class LoadPlanner:
>>> self.is_coordinator = is_coordinator >>> self.is_coordinator = is_coordinator
>>> >>>
>>> def load_bytes(self, read_item, value): >>> def load_bytes(self, read_item, value):
>>> # Remove the "foo_" prefix >>> # Remove the "foo_" prefix
>>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False) >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

View File

@ -140,10 +140,12 @@ class StateDictOptions:
@dataclass @dataclass
class _StateDictInfo(StateDictOptions): class _StateDictInfo(StateDictOptions):
fqn_param_mapping: dict[ fqn_param_mapping: dict[
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] Union[str, torch.Tensor],
Union[FQNS_T, torch.Tensor],
] = field(default_factory=dict) ] = field(default_factory=dict)
shared_params_mapping: dict[ shared_params_mapping: dict[
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] Union[str, torch.Tensor],
Union[FQNS_T, torch.Tensor],
] = field(default_factory=dict) ] = field(default_factory=dict)
submodule_prefixes: set[str] = field(default_factory=set) submodule_prefixes: set[str] = field(default_factory=set)
handle_model: bool = True handle_model: bool = True
@ -1140,7 +1142,9 @@ def get_state_dict(
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim) >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
... fsdp_model, fsdp_optim
... )
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts will fail. >>> # the asserts will fail.

View File

@ -125,7 +125,9 @@ def load(
>>> my_model = MyModule() >>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters()) >>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict() >>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1") >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
... "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.load_state_dict( >>> torch.distributed.checkpoint.load_state_dict(
>>> state_dict=model_state_dict, >>> state_dict=model_state_dict,

View File

@ -127,7 +127,9 @@ def save(
>>> state_dict = {"model": my_model} >>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
... "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.save( >>> torch.distributed.checkpoint.save(
>>> state_dict=state_dict, >>> state_dict=state_dict,
>>> storage_writer=fs_storage_writer, >>> storage_writer=fs_storage_writer,
@ -206,7 +208,9 @@ def async_save(
>>> state_dict = {"model": my_model} >>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
... "/checkpoint/1"
... )
>>> checkpoint_future = torch.distributed.checkpoint.async_save( >>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>> state_dict=state_dict, >>> state_dict=state_dict,
>>> storage_writer=fs_storage_writer, >>> storage_writer=fs_storage_writer,
@ -223,7 +227,9 @@ def async_save(
pg = process_group or _get_default_group() pg = process_group or _get_default_group()
assert ( assert (
torch.device("cpu") in pg._device_types # type: ignore[attr-defined] torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" ), (
"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
)
storage_writer = cast( storage_writer = cast(
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)

View File

@ -32,7 +32,7 @@ R = TypeVar("R")
def _get_failure_dict( def _get_failure_dict(
results: list[Union[T, WRAPPED_EXCEPTION]] results: list[Union[T, WRAPPED_EXCEPTION]],
) -> dict[int, WRAPPED_EXCEPTION]: ) -> dict[int, WRAPPED_EXCEPTION]:
return cast( return cast(
dict[int, WRAPPED_EXCEPTION], dict[int, WRAPPED_EXCEPTION],

View File

@ -221,8 +221,12 @@ else:
if cur_rank in mesh_nd: if cur_rank in mesh_nd:
res_flattened_mesh = flattened_mesh res_flattened_mesh = flattened_mesh
self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined] self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined]
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh # type: ignore[possibly-undefined] self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root) # type: ignore[possibly-undefined] res_flattened_mesh # type: ignore[possibly-undefined]
)
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(
flatten_dims_in_root
) # type: ignore[possibly-undefined]
return res_flattened_mesh return res_flattened_mesh
@ -242,9 +246,9 @@ else:
root_mesh = self.get_root_mesh(device_mesh) root_mesh = self.get_root_mesh(device_mesh)
child_mesh_dim_names = device_mesh.mesh_dim_names child_mesh_dim_names = device_mesh.mesh_dim_names
if root_mesh and child_mesh_dim_names: if root_mesh and child_mesh_dim_names:
assert ( assert len(child_mesh_dim_names) == 1, (
len(child_mesh_dim_names) == 1 "The submesh can only be a 1D mesh."
), "The submesh can only be a 1D mesh." )
child_mesh_dim_name = child_mesh_dim_names[0] child_mesh_dim_name = child_mesh_dim_names[0]
return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name) return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name)
return None return None
@ -763,7 +767,9 @@ else:
root_mesh, None root_mesh, None
) )
if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2] # type: ignore[index] dim_group_infos = root_to_flatten_mapping[
mesh_dim # type: ignore[index]
]._dim_group_infos[0][:2]
return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos)) return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos))
else: else:
mesh_dim = ( mesh_dim = (
@ -905,9 +911,9 @@ else:
mesh_dim = 0 mesh_dim = 0
mesh_dim_group = not_none(self.get_group(mesh_dim)) mesh_dim_group = not_none(self.get_group(mesh_dim))
assert isinstance( assert isinstance(mesh_dim_group, ProcessGroup), (
mesh_dim_group, ProcessGroup "We expect ProcessGroup before calling `get_rank`!"
), "We expect ProcessGroup before calling `get_rank`!" )
return not_none(get_rank(mesh_dim_group)) return not_none(get_rank(mesh_dim_group))
def get_coordinate(self) -> Optional[list[int]]: def get_coordinate(self) -> Optional[list[int]]:

View File

@ -334,12 +334,12 @@ class Backend(str): # noqa: SLOT000
# Allow UCC plugin if Pytorch is not built with native support. # Allow UCC plugin if Pytorch is not built with native support.
# TODO: remove this exception once UCC plugin is fully deprecated. # TODO: remove this exception once UCC plugin is fully deprecated.
if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()): if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()):
assert not hasattr( assert not hasattr(Backend, name.upper()), (
Backend, name.upper() f"{name.upper()} c10d backend already exist"
), f"{name.upper()} c10d backend already exist" )
assert ( assert name.upper() not in Backend._plugins, (
name.upper() not in Backend._plugins f"{name.upper()} c10d backend creator function already exist"
), f"{name.upper()} c10d backend creator function already exist" )
setattr(Backend, name.upper(), name.lower()) setattr(Backend, name.upper(), name.lower())
Backend.backend_list.append(name.lower()) Backend.backend_list.append(name.lower())
@ -1650,9 +1650,9 @@ def init_process_group(
if "torch._dynamo" in sys.modules: if "torch._dynamo" in sys.modules:
torch._dynamo.trace_rules.clear_lru_cache() torch._dynamo.trace_rules.clear_lru_cache()
assert (store is None) or ( assert (store is None) or (init_method is None), (
init_method is None "Cannot specify both init_method and store."
), "Cannot specify both init_method and store." )
if store is not None: if store is not None:
assert world_size > 0, "world_size must be positive if using store" assert world_size > 0, "world_size must be positive if using store"
@ -1734,7 +1734,10 @@ def init_process_group(
) )
_update_default_pg(default_pg) _update_default_pg(default_pg)
_world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index] _world.pg_group_ranks[GroupMember.WORLD] = { # type: ignore[index]
i: i
for i in range(GroupMember.WORLD.size()) # type: ignore[attr-defined]
}
_backend = _world.pg_map[not_none(GroupMember.WORLD)][0] _backend = _world.pg_map[not_none(GroupMember.WORLD)][0]
_default_pg_init_method = init_method _default_pg_init_method = init_method
@ -1959,9 +1962,9 @@ def _new_process_group_helper(
if not is_nccl_available(): if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL built in") raise RuntimeError("Distributed package doesn't have NCCL built in")
if backend_options is not None: if backend_options is not None:
assert isinstance( assert isinstance(backend_options, ProcessGroupNCCL.Options), (
backend_options, ProcessGroupNCCL.Options "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
), "Expected backend_options argument to be of type ProcessGroupNCCL.Options" )
if backend_options._timeout != timeout: if backend_options._timeout != timeout:
warnings.warn( warnings.warn(
"backend_options._timeout was specified, " "backend_options._timeout was specified, "
@ -2001,9 +2004,9 @@ def _new_process_group_helper(
) )
backend_type = ProcessGroup.BackendType.XCCL backend_type = ProcessGroup.BackendType.XCCL
else: else:
assert ( assert backend_str.upper() in Backend._plugins, (
backend_str.upper() in Backend._plugins f"Unknown c10d backend type {backend_str.upper()}"
), f"Unknown c10d backend type {backend_str.upper()}" )
backend_plugin = Backend._plugins[backend_str.upper()] backend_plugin = Backend._plugins[backend_str.upper()]
creator_fn = backend_plugin.creator_fn creator_fn = backend_plugin.creator_fn
@ -2630,8 +2633,10 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
>>> # xdoctest: +SKIP("no rank") >>> # xdoctest: +SKIP("no rank")
>>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
>>> recv_tensor = torch.randn(2, dtype=torch.float32) >>> recv_tensor = torch.randn(2, dtype=torch.float32)
>>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size) >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
>>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size) >>> recv_op = dist.P2POp(
... dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
... )
>>> reqs = batch_isend_irecv([send_op, recv_op]) >>> reqs = batch_isend_irecv([send_op, recv_op])
>>> for req in reqs: >>> for req in reqs:
>>> req.wait() >>> req.wait()
@ -2758,7 +2763,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
>>> # xdoctest: +SKIP("no rank") >>> # xdoctest: +SKIP("no rank")
>>> # All tensors below are of torch.int64 type. >>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks. >>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f'cuda:{rank}') >>> device = torch.device(f"cuda:{rank}")
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor >>> tensor
tensor([1, 2], device='cuda:0') # Rank 0 tensor([1, 2], device='cuda:0') # Rank 0
@ -2770,7 +2775,9 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
>>> # All tensors below are of torch.cfloat type. >>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks. >>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j) >>> tensor = torch.tensor(
... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor >>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
@ -3380,9 +3387,9 @@ def recv_object_list(
) )
rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src) rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
assert ( assert rank_sizes == rank_objects, (
rank_sizes == rank_objects "Mismatch in return ranks for object sizes and objects."
), "Mismatch in return ranks for object sizes and objects." )
# Deserialize objects using their stored sizes. # Deserialize objects using their stored sizes.
offset = 0 offset = 0
for i, obj_size in enumerate(object_sizes_tensor): for i, obj_size in enumerate(object_sizes_tensor):
@ -3673,8 +3680,10 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
>>> # xdoctest: +SKIP("need process group init") >>> # xdoctest: +SKIP("need process group init")
>>> # All tensors below are of torch.int64 dtype. >>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks. >>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f'cuda:{rank}') >>> device = torch.device(f"cuda:{rank}")
>>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)] >>> tensor_list = [
... torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
... ]
>>> tensor_list >>> tensor_list
[tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0 [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
[tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1 [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
@ -3689,11 +3698,15 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
>>> # All tensors below are of torch.cfloat dtype. >>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks. >>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)] >>> tensor_list = [
... torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
... ]
>>> tensor_list >>> tensor_list
[tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0 [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
[tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1 [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j) >>> tensor = torch.tensor(
... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor >>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
@ -3769,7 +3782,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
>>> # xdoctest: +SKIP("need process group init") >>> # xdoctest: +SKIP("need process group init")
>>> # All tensors below are of torch.int64 dtype and on CUDA devices. >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks. >>> # We have two ranks.
>>> device = torch.device(f'cuda:{rank}') >>> device = torch.device(f"cuda:{rank}")
>>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor_in >>> tensor_in
tensor([1, 2], device='cuda:0') # Rank 0 tensor([1, 2], device='cuda:0') # Rank 0
@ -3969,8 +3982,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
) )
elif gather_list: elif gather_list:
raise ValueError( raise ValueError(
"Argument ``gather_list`` must NOT be specified " "Argument ``gather_list`` must NOT be specified on non-destination ranks."
"on non-destination ranks."
) )
@ -4141,8 +4153,7 @@ def scatter(
else: else:
if scatter_list: if scatter_list:
raise ValueError( raise ValueError(
"Argument ``scatter_list`` must NOT be specified " "Argument ``scatter_list`` must NOT be specified on non-source ranks."
"on non-source ranks."
) )
input_tensors = [] input_tensors = []
output_tensors = [tensor] output_tensors = [tensor]
@ -4225,7 +4236,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
>>> # xdoctest: +SKIP("need process group init") >>> # xdoctest: +SKIP("need process group init")
>>> # All tensors below are of torch.int64 dtype and on CUDA devices. >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks. >>> # We have two ranks.
>>> device = torch.device(f'cuda:{rank}') >>> device = torch.device(f"cuda:{rank}")
>>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device) >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
>>> # Input in concatenation form >>> # Input in concatenation form
>>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device) >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
@ -4381,7 +4392,7 @@ def all_to_all_single(
>>> # Essentially, it is similar to following operation: >>> # Essentially, it is similar to following operation:
>>> scatter_list = list(input.chunk(world_size)) >>> scatter_list = list(input.chunk(world_size))
>>> gather_list = list(output.chunk(world_size)) >>> gather_list = list(output.chunk(world_size))
>>> for i in range(world_size): >>> for i in range(world_size):
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
@ -4411,7 +4422,9 @@ def all_to_all_single(
>>> # Another example with tensors of torch.cfloat type. >>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j) >>> input = torch.tensor(
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input >>> input
tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0 tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0
tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1 tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1
@ -4510,7 +4523,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
>>> # Essentially, it is similar to following operation: >>> # Essentially, it is similar to following operation:
>>> scatter_list = input >>> scatter_list = input
>>> gather_list = output >>> gather_list = output
>>> for i in range(world_size): >>> for i in range(world_size):
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i) >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
@ -4544,7 +4557,9 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3
>>> # Another example with tensors of torch.cfloat type. >>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j) >>> input = torch.tensor(
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input = list(input.chunk(4)) >>> input = list(input.chunk(4))
>>> input >>> input
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0 [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0
@ -4882,9 +4897,9 @@ def split_group(
backend_config = BackendConfig(backend) backend_config = BackendConfig(backend)
if pg_options is not None: if pg_options is not None:
assert isinstance( assert isinstance(pg_options, ProcessGroupNCCL.Options), (
pg_options, ProcessGroupNCCL.Options "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" )
else: else:
# default pg_options same as the parent process group # default pg_options same as the parent process group
pg_options = parent_backend.options pg_options = parent_backend.options
@ -5086,9 +5101,9 @@ def _new_group_with_tag(
if device_id is None: if device_id is None:
device_id = default_pg.bound_device_id device_id = default_pg.bound_device_id
elif default_pg.bound_device_id is not None: elif default_pg.bound_device_id is not None:
assert ( assert device_id == default_pg.bound_device_id, (
device_id == default_pg.bound_device_id "Mismatched bound device between new pg and the default pg."
), "Mismatched bound device between new pg and the default pg." )
default_backend, default_store = _world.pg_map[default_pg] default_backend, default_store = _world.pg_map[default_pg]
global_rank = default_pg.rank() global_rank = default_pg.rank()
global_world_size = default_pg.size() global_world_size = default_pg.size()
@ -5408,9 +5423,9 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro
def _find_or_create_pg_by_ranks_and_tag( def _find_or_create_pg_by_ranks_and_tag(
tag: str, ranks: list[int], stride: int tag: str, ranks: list[int], stride: int
) -> ProcessGroup: ) -> ProcessGroup:
assert ( assert len(ranks) % stride == 0, (
len(ranks) % stride == 0 f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" )
my_rank = get_rank() my_rank = get_rank()
my_ranks = None my_ranks = None
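
Much of the churn in the process-group and collectives hunks above is inside docstrings: the doctest snippets are reformatted too, so single-quoted strings become double-quoted and over-long calls are wrapped with `...` continuation markers. A hedged sketch of the resulting shape (illustrative, not an executable doctest from the diff):

    def _doc_sketch() -> None:
        """
        >>> # xdoctest: +SKIP("illustrative only")
        >>> device = torch.device(f"cuda:{rank}")
        >>> tensor = torch.tensor(
        ...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
        ... ) + 2 * rank * (1 + 1j)
        """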

View File

@ -40,8 +40,9 @@ def worker_main() -> Generator[None, None, None]:
def main(): def main():
pass pass
if __name__=="__main__":
main() if __name__ == "__main__":
main()
""" """
with ExitStack() as stack: with ExitStack() as stack:

View File

@ -14,7 +14,10 @@ Example of usage:
:: ::
from torch.distributed.elastic import events from torch.distributed.elastic import events
event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
event = events.Event(
name="test_event", source=events.EventSource.WORKER, metadata={...}
)
events.get_logging_handler(destination="console").info(event) events.get_logging_handler(destination="console").info(event)
""" """

View File

@ -52,11 +52,12 @@ The example below measures the latency for the ``calculate()`` function.
metrics.configure(metrics.NullMetricsHandler()) metrics.configure(metrics.NullMetricsHandler())
metrics.configure(metrics.ConsoleMetricsHandler(), "my_module") metrics.configure(metrics.ConsoleMetricsHandler(), "my_module")
def my_method(): def my_method():
start = time.time() start = time.time()
calculate() calculate()
end = time.time() end = time.time()
metrics.put_metric("calculate_latency", int(end-start), "my_module") metrics.put_metric("calculate_latency", int(end - start), "my_module")
You may also use the torch.distributed.elastic.metrics.prof` decorator You may also use the torch.distributed.elastic.metrics.prof` decorator
to conveniently and succinctly profile functions to conveniently and succinctly profile functions
@ -70,15 +71,16 @@ to conveniently and succinctly profile functions
metrics.configure(metrics.ConsoleMetricsHandler(), "foobar") metrics.configure(metrics.ConsoleMetricsHandler(), "foobar")
metrics.configure(metrics.ConsoleMetricsHandler(), "Bar") metrics.configure(metrics.ConsoleMetricsHandler(), "Bar")
@metrics.prof @metrics.prof
def foo(): def foo():
pass pass
class Bar():
@metrics.prof class Bar:
def baz(): @metrics.prof
pass def baz():
pass
``@metrics.prof`` will publish the following metrics ``@metrics.prof`` will publish the following metrics
:: ::
@ -102,8 +104,8 @@ console.
import torch.distributed.elastic.metrics as metrics import torch.distributed.elastic.metrics as metrics
metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic") metrics.configure(metrics.ConsoleMetricHandler(), group="torchelastic")
metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app") metrics.configure(metrics.ConsoleMetricHandler(), group="my_app")
**Writing a Custom Metric Handler**: **Writing a Custom Metric Handler**:
@ -117,13 +119,15 @@ Below is a toy example that prints the metrics to ``stdout``
import torch.distributed.elastic.metrics as metrics import torch.distributed.elastic.metrics as metrics
class StdoutMetricHandler(metrics.MetricHandler): class StdoutMetricHandler(metrics.MetricHandler):
def emit(self, metric_data): def emit(self, metric_data):
ts = metric_data.timestamp ts = metric_data.timestamp
group = metric_data.group_name group = metric_data.group_name
name = metric_data.name name = metric_data.name
value = metric_data.value value = metric_data.value
print(f"[{ts}][{group}]: {name}={value}") print(f"[{ts}][{group}]: {name}={value}")
metrics.configure(StdoutMetricHandler(), group="my_app") metrics.configure(StdoutMetricHandler(), group="my_app")

View File

@ -123,6 +123,7 @@ def prof(fn=None, group: str = "torchelastic"):
def x(): def x():
pass pass
@metrics.prof(group="agent") @metrics.prof(group="agent")
def y(): def y():
pass pass

View File

@ -20,22 +20,23 @@ Usage 1: Launching two trainers as a function
from torch.distributed.elastic.multiprocessing import Std, start_processes from torch.distributed.elastic.multiprocessing import Std, start_processes
def trainer(a, b, c): def trainer(a, b, c):
pass # train pass # train
# runs two trainers # runs two trainers
# LOCAL_RANK=0 trainer(1,2,3) # LOCAL_RANK=0 trainer(1,2,3)
# LOCAL_RANK=1 trainer(4,5,6) # LOCAL_RANK=1 trainer(4,5,6)
ctx = start_processes( ctx = start_processes(
name="trainer", name="trainer",
entrypoint=trainer, entrypoint=trainer,
args={0: (1,2,3), 1: (4,5,6)}, args={0: (1, 2, 3), 1: (4, 5, 6)},
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}}, envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
log_dir="/tmp/foobar", log_dir="/tmp/foobar",
redirects=Std.ALL, # write all worker stdout/stderr to a log file redirects=Std.ALL, # write all worker stdout/stderr to a log file
tee={0: Std.ERR}, # tee only local rank 0's stderr to console tee={0: Std.ERR}, # tee only local rank 0's stderr to console
) )
# waits for all copies of trainer to finish # waits for all copies of trainer to finish
ctx.wait() ctx.wait()

View File

@ -165,9 +165,11 @@ def to_map(
Example: Example:
:: ::
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT} to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} to_map(
{0: Std.OUT, 1: Std.OUT}, local_world_size=2
) # returns: {0: Std.OUT, 1: Std.OUT}
""" """
if isinstance(val_or_map, Std): if isinstance(val_or_map, Std):
return dict.fromkeys(range(local_world_size), val_or_map) return dict.fromkeys(range(local_world_size), val_or_map)
@ -304,7 +306,9 @@ class DefaultLogsSpecs(LogsSpecs):
if not self._run_log_dir: if not self._run_log_dir:
self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id) self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload] attempt_log_dir = os.path.join(
self._run_log_dir, f"attempt_{restart_count}"
) # type: ignore[call-overload]
shutil.rmtree(attempt_log_dir, ignore_errors=True) shutil.rmtree(attempt_log_dir, ignore_errors=True)
os.makedirs(attempt_log_dir) os.makedirs(attempt_log_dir)
@ -868,9 +872,7 @@ class SubprocessContext(PContext):
if result.is_failed(): if result.is_failed():
first_failure = min(result.failures.values(), key=lambda f: f.timestamp) first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
logger.error( logger.error(
"failed (exitcode: %s)" "failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s",
" local_rank: %s (pid: %s)"
" of binary: %s",
first_failure.exitcode, first_failure.exitcode,
first_failure.local_rank, first_failure.local_rank,
first_failure.pid, first_failure.pid,
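This hunk also joins adjacent string literals that were previously kept as an implicit concatenation across lines. A small sketch of the same change, with hypothetical log arguments standing in for the real failure fields:

import logging

logger = logging.getLogger(__name__)

# Old layout: the format string was assembled from implicitly concatenated pieces.
logger.error(
    "failed (exitcode: %s)" " local_rank: %s (pid: %s)" " of binary: %s",
    1,
    0,
    4242,
    "/usr/bin/python",
)

# New layout: one literal, short enough to sit on a single line.
logger.error(
    "failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s",
    1,
    0,
    4242,
    "/usr/bin/python",
)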

View File

@ -318,14 +318,14 @@ def record(
error_handler = get_error_handler() error_handler = get_error_handler()
error_handler.initialize() error_handler.initialize()
try: try:
foobar() foobar()
except ChildFailedError as e: except ChildFailedError as e:
_, failure = e.get_first_failure() _, failure = e.get_first_failure()
error_handler.dump_error_file(failure.error_file, failure.exitcode) error_handler.dump_error_file(failure.error_file, failure.exitcode)
raise raise
except Exception as e: except Exception as e:
error_handler.record_exception(e) error_handler.record_exception(e)
raise raise
.. important:: use this decorator once per process at the top-level method; .. important:: use this decorator once per process at the top-level method;
typically this is the main method. typically this is the main method.
@ -338,8 +338,9 @@ def record(
def main(): def main():
pass pass
if __name__=="__main__":
main() if __name__ == "__main__":
main()
""" """
if not error_handler: if not error_handler:

View File

@ -120,11 +120,7 @@ of the following implementations that come with PyTorch:
backend = C10dRendezvousBackend(store, "my_run_id") backend = C10dRendezvousBackend(store, "my_run_id")
rdzv_handler = DynamicRendezvousHandler.from_backend( rdzv_handler = DynamicRendezvousHandler.from_backend(
run_id="my_run_id", run_id="my_run_id", store=store, backend=backend, min_nodes=2, max_nodes=4
store=store,
backend=backend,
min_nodes=2,
max_nodes=4
) )
""" """

View File

@ -89,8 +89,14 @@ class RendezvousStoreInfo:
addr = local_addr or socket.getfqdn() addr = local_addr or socket.getfqdn()
# When TCPStore is not shared, we fallback to get_free_port. # When TCPStore is not shared, we fallback to get_free_port.
port = server_port or get_free_port() port = server_port or get_free_port()
store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type] store.set(
store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type] RendezvousStoreInfo.MASTER_ADDR_KEY,
addr.encode(encoding="UTF-8"), # type: ignore[arg-type]
)
store.set(
RendezvousStoreInfo.MASTER_PORT_KEY,
str(port).encode(encoding="UTF-8"), # type: ignore[arg-type]
)
addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
port = int( port = int(
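When a long call is split across lines, the per-argument ``# type: ignore`` comments move onto the argument they suppress, as in the ``store.set(...)`` lines above. A sketch with hypothetical stand-ins (``set_key`` is made up; the hunk does the same thing with a ``TCPStore``):

def set_key(key: str, value: bytes) -> None:
    # Hypothetical stand-in for store.set(...).
    print(key, value)

MASTER_ADDR_KEY = "MASTER_ADDR"
addr = "node0.example.com"

# Old layout: one long line, with the ignore comment at the end of the call.
set_key(MASTER_ADDR_KEY, addr.encode(encoding="UTF-8"))  # type: ignore[arg-type]

# New layout: the call is split, and the comment sits on the argument it targets.
set_key(
    MASTER_ADDR_KEY,
    addr.encode(encoding="UTF-8"),  # type: ignore[arg-type]
)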

View File

@ -413,9 +413,9 @@ class EtcdRendezvous:
active_version = self.wait_for_peers(expected_version) active_version = self.wait_for_peers(expected_version)
state = json.loads(active_version.value) state = json.loads(active_version.value)
assert ( assert state["version"] == expected_version, (
state["version"] == expected_version "Logic error: failed to observe version mismatch"
), "Logic error: failed to observe version mismatch" )
return self.confirm_phase(expected_version, this_rank) return self.confirm_phase(expected_version, this_rank)
@ -533,9 +533,9 @@ class EtcdRendezvous:
"Rendezvous version changed. Must try join the new one." "Rendezvous version changed. Must try join the new one."
) )
assert ( assert len(state["participants"]) < self._num_max_workers, (
len(state["participants"]) < self._num_max_workers "Logic error: joinable rendezvous should always have space left"
), "Logic error: joinable rendezvous should always have space left" )
this_rank = len(state["participants"]) this_rank = len(state["participants"])
state["participants"].append(this_rank) state["participants"].append(this_rank)

View File

@ -86,11 +86,15 @@ def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
from torch.distributed.elastic.rendezvous import rendezvous_handler_registry from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
def create_my_rdzv(params: RendezvousParameters): def create_my_rdzv(params: RendezvousParameters):
return MyCustomRdzv(params) return MyCustomRdzv(params)
rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv) rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters) my_rdzv_handler = get_rendezvous_handler(
"my_rdzv_backend_name", RendezvousParameters
)
""" """
return handler_registry.create_handler(params) return handler_registry.create_handler(params)

View File

@ -57,10 +57,10 @@ def get_all(store, rank: int, prefix: str, world_size: int):
:: ::
values = get_all(store, 'torchelastic/data', 3) values = get_all(store, "torchelastic/data", 3)
value1 = values[0] # retrieves the data for key torchelastic/data0 value1 = values[0] # retrieves the data for key torchelastic/data0
value2 = values[1] # retrieves the data for key torchelastic/data1 value2 = values[1] # retrieves the data for key torchelastic/data1
value3 = values[2] # retrieves the data for key torchelastic/data2 value3 = values[2] # retrieves the data for key torchelastic/data2
""" """
data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])

View File

@ -2,6 +2,7 @@
""" """
This file includes private common utilities for FSDP. This file includes private common utilities for FSDP.
""" """
import logging import logging
import traceback import traceback
import warnings import warnings
@ -200,9 +201,9 @@ def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamH
# handles, meaning no entry in `_fully_sharded_module_to_handles` # handles, meaning no entry in `_fully_sharded_module_to_handles`
if state._handle is None: if state._handle is None:
return None return None
assert ( assert module in state._fully_sharded_module_to_handle, (
module in state._fully_sharded_module_to_handle f"Expects a fully sharded module but got {module} on rank {state.rank}"
), f"Expects a fully sharded module but got {module} on rank {state.rank}" )
return state._fully_sharded_module_to_handle[module] return state._fully_sharded_module_to_handle[module]
else: else:
# NOTE: This assumes `module` is a `FullyShardedDataParallel` instance. # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
@ -255,9 +256,9 @@ def _named_parameters_with_duplicates(
This API is required as some modules overwrite `named_parameters()` but do not support This API is required as some modules overwrite `named_parameters()` but do not support
`remove_duplicate`. `remove_duplicate`.
""" """
assert ( assert "remove_duplicate" not in kwargs, (
"remove_duplicate" not in kwargs "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument." )
kwargs["remove_duplicate"] = False kwargs["remove_duplicate"] = False
try: try:
ret = list(module.named_parameters(**kwargs)) ret = list(module.named_parameters(**kwargs))

View File

@ -190,9 +190,9 @@ class _ExecOrderData:
return return
if self.is_first_iter: if self.is_first_iter:
msg_prefix = "Forward order differs across ranks:" msg_prefix = "Forward order differs across ranks:"
optional_local_indices: tuple[ optional_local_indices: tuple[Optional[int], ...] = (
Optional[int], ... self._get_handle_indices(handle)
] = self._get_handle_indices(handle) )
device = handle.device # guaranteed to be non-CPU device = handle.device # guaranteed to be non-CPU
num_valid_indices = sum( num_valid_indices = sum(
(index is not None) for index in optional_local_indices (index is not None) for index in optional_local_indices
@ -250,8 +250,7 @@ class _ExecOrderData:
( (
rank, rank,
world_indices[ world_indices[
rank rank * num_valid_indices : (rank + 1)
* num_valid_indices : (rank + 1)
* num_valid_indices * num_valid_indices
], ],
) )

View File

@ -586,7 +586,10 @@ class FlatParamHandle:
) )
self._fsdp_extension = fsdp_extension self._fsdp_extension = fsdp_extension
self._init_flat_param_and_metadata( self._init_flat_param_and_metadata(
params, fully_sharded_module, self._aligned_numel, use_orig_params # type: ignore[arg-type] params,
fully_sharded_module,
self._aligned_numel,
use_orig_params, # type: ignore[arg-type]
) )
self._use_unsharded_views(as_params=False) self._use_unsharded_views(as_params=False)
@ -978,9 +981,9 @@ class FlatParamHandle:
shard_param_infos = self._get_shard_metadata( shard_param_infos = self._get_shard_metadata(
unsharded_start_idx, unsharded_end_idx unsharded_start_idx, unsharded_end_idx
) )
assert ( assert len(shard_param_infos) == flat_param._num_params, (
len(shard_param_infos) == flat_param._num_params f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}" )
flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined] flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined] flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
@ -996,9 +999,9 @@ class FlatParamHandle:
unsharded flat parameter specifying the shard. unsharded flat parameter specifying the shard.
""" """
flat_param_offsets = self._get_flat_param_offsets() flat_param_offsets = self._get_flat_param_offsets()
assert len(flat_param_offsets) == len( assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), (
self.flat_param._numels_with_padding f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}" )
shard_param_infos: list[_ShardParamInfo] = [] shard_param_infos: list[_ShardParamInfo] = []
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1 sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
@ -1075,9 +1078,9 @@ class FlatParamHandle:
else: else:
chunk = chunks[rank] chunk = chunks[rank]
numel_to_pad = chunks[0].numel() - chunk.numel() numel_to_pad = chunks[0].numel() - chunk.numel()
assert ( assert numel_to_pad >= 0, (
numel_to_pad >= 0 "Chunk's size should be at most the first chunk's size"
), "Chunk's size should be at most the first chunk's size" )
return chunk, numel_to_pad return chunk, numel_to_pad
@staticmethod @staticmethod
@ -1302,7 +1305,8 @@ class FlatParamHandle:
self._check_low_precision_shard() self._check_low_precision_shard()
flat_param = self.flat_param flat_param = self.flat_param
_alloc_storage( _alloc_storage(
flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined] flat_param._mp_shard,
flat_param._local_shard.size(), # type: ignore[attr-defined]
) )
# `copy_()` implicitly casts to the low precision # `copy_()` implicitly casts to the low precision
flat_param._mp_shard.copy_( # type: ignore[attr-defined] flat_param._mp_shard.copy_( # type: ignore[attr-defined]
@ -1498,7 +1502,8 @@ class FlatParamHandle:
# default stream suffices since the default stream waits for the # default stream suffices since the default stream waits for the
# unshard stream. # unshard stream.
_no_dispatch_record_stream( _no_dispatch_record_stream(
self.flat_param._mp_shard, self._device_handle.current_stream() # type: ignore[attr-defined] self.flat_param._mp_shard,
self._device_handle.current_stream(), # type: ignore[attr-defined]
) )
_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined] _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
@ -1593,8 +1598,7 @@ class FlatParamHandle:
f"but got {flat_param.grad.device}", f"but got {flat_param.grad.device}",
) )
prev_iter_synced_gradients = ( prev_iter_synced_gradients = (
flat_param.grad.size() flat_param.grad.size() == flat_param._local_shard.size() # type: ignore[attr-defined]
== flat_param._local_shard.size() # type: ignore[attr-defined]
) )
if prev_iter_synced_gradients: if prev_iter_synced_gradients:
# TODO (awgu): Gradient accumulation outside `no_sync()` # TODO (awgu): Gradient accumulation outside `no_sync()`
@ -1668,8 +1672,7 @@ class FlatParamHandle:
cast_grad_to_param_dtype_if_needed(flat_param) cast_grad_to_param_dtype_if_needed(flat_param)
else: else:
_p_assert( _p_assert(
not self.uses_sharded_strategy not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined]
or not flat_param._post_backward_called, # type: ignore[attr-defined]
"All sharded parameters that received a gradient in the " "All sharded parameters that received a gradient in the "
"post-backward should use `_saved_grad_shard`", "post-backward should use `_saved_grad_shard`",
) )
@ -2504,7 +2507,8 @@ class FlatParamHandle:
"""Return the FQNs of the parameters present in this rank's shard.""" """Return the FQNs of the parameters present in this rank's shard."""
fqns_in_shard: list[str] = [] fqns_in_shard: list[str] = []
for fqn, shard_param_info in zip( for fqn, shard_param_info in zip(
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined] self.flat_param._fqns,
self.flat_param._shard_param_infos, # type: ignore[attr-defined]
): ):
if shard_param_info.in_shard: if shard_param_info.in_shard:
fqns_in_shard.append(fqn) fqns_in_shard.append(fqn)
@ -2694,7 +2698,7 @@ def _safe_setattr_tensor_or_param(
def _convert_to_params( def _convert_to_params(
tensors: list[Union[torch.Tensor, nn.Parameter]] tensors: list[Union[torch.Tensor, nn.Parameter]],
) -> list[nn.Parameter]: ) -> list[nn.Parameter]:
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors] return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
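The ``_convert_to_params`` hunk shows another recurring change: a lone parameter that stays on its own line now gets a trailing comma. A sketch of the two layouts, using made-up function names:

from typing import Union

import torch
from torch import nn

# Old layout: the split single-parameter signature carried no trailing comma.
def to_params_old_style(
    tensors: list[Union[torch.Tensor, nn.Parameter]]
) -> list[nn.Parameter]:
    return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]

# New layout: ruff format adds the trailing comma after the lone parameter.
def to_params_new_style(
    tensors: list[Union[torch.Tensor, nn.Parameter]],
) -> list[nn.Parameter]:
    return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]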

View File

@ -374,9 +374,9 @@ def foreach_reduce(
for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)): for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)):
if (shard_dim := fsdp_param.fsdp_placement.dim) == 0: if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
continue continue
assert ( assert unsharded_grad.size(shard_dim) % world_size == 0, (
unsharded_grad.size(shard_dim) % world_size == 0 f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}" )
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim) chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
unsharded_grads[i] = torch.cat(chunks, dim=0) unsharded_grads[i] = torch.cat(chunks, dim=0)
padded_unsharded_sizes = tuple( padded_unsharded_sizes = tuple(

View File

@ -26,9 +26,9 @@ if torch._running_with_deploy():
else: else:
def detect_compiled_autograd(): def detect_compiled_autograd():
assert ( assert not torch.compiler.is_compiling(), (
not torch.compiler.is_compiling() "`detect_compiled_autograd()` is designed to be called in eager mode"
), "`detect_compiled_autograd()` is designed to be called in eager mode" )
global _compiled_autograd_enabled global _compiled_autograd_enabled
import torch._dynamo.compiled_autograd as ca import torch._dynamo.compiled_autograd as ca

View File

@ -304,9 +304,9 @@ class FSDPParam:
f"FSDP only supports 1D TP, not {self._tp_spec.placements}" f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
) )
split_factor = self._tp_spec.num_shards_map[shard_dim] split_factor = self._tp_spec.num_shards_map[shard_dim]
assert ( assert 2 <= self._spmd_mesh.ndim <= 3, (
2 <= self._spmd_mesh.ndim <= 3 f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." )
self._spmd_placements: tuple[Placement, ...] self._spmd_placements: tuple[Placement, ...]
dp_shard_tp_placement = ( dp_shard_tp_placement = (
( (
@ -520,8 +520,9 @@ class FSDPParam:
unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
if hasattr(self, "_unsharded_param"): if hasattr(self, "_unsharded_param"):
assert compiled_autograd_enabled() assert compiled_autograd_enabled()
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( with (
self._unsharded_param torch.no_grad(),
torch.autograd._unsafe_preserve_version_counter(self._unsharded_param),
): ):
# NOTE: Under compile, if an unsharded param goes through # NOTE: Under compile, if an unsharded param goes through
# resize_(full) -> copy_ -> resize_(0) pattern, we will remove those # resize_(full) -> copy_ -> resize_(0) pattern, we will remove those
@ -785,9 +786,9 @@ class FSDPParam:
assert isinstance(grad, DTensor), f"{type(grad)}" assert isinstance(grad, DTensor), f"{type(grad)}"
placements = self._tp_spec.placements placements = self._tp_spec.placements
if placements != grad.placements: if placements != grad.placements:
assert len(self._tp_spec.placements) == len( assert len(self._tp_spec.placements) == len(grad.placements), (
grad.placements f"{self._tp_spec=} {grad.placements=}"
), f"{self._tp_spec=} {grad.placements=}" )
grad = grad.redistribute(placements=placements) grad = grad.redistribute(placements=placements)
grad = grad._local_tensor grad = grad._local_tensor
return grad return grad
@ -846,9 +847,9 @@ class FSDPParam:
shard_dim = self.fsdp_placement.dim shard_dim = self.fsdp_placement.dim
length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0 length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
if local_tensor.size() != padded_sharded_size: if local_tensor.size() != padded_sharded_size:
assert ( assert shard_dim == 0, (
shard_dim == 0 f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
), f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}" )
padded_local_tensor = local_tensor.new_zeros(padded_sharded_size) padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_( padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(
local_tensor local_tensor
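The ``_unsafe_preserve_version_counter`` hunk above switches a multi-manager ``with`` statement from the old layout (managers chained on the ``with`` line, wrapping inside the last call) to a parenthesized list of managers, one per line. A schematic sketch, with a hypothetical helper in place of the FSDP context manager and relying on the parenthesized-``with`` syntax the codebase's supported Python versions accept:

import contextlib

import torch

@contextlib.contextmanager
def preserve_state(tag):
    # Hypothetical stand-in for torch.autograd._unsafe_preserve_version_counter.
    yield

# Old layout: managers chained directly on the `with` line.
with torch.no_grad(), preserve_state("unsharded_param"):
    pass

# New layout: managers wrapped in parentheses, one per line.
with (
    torch.no_grad(),
    preserve_state("unsharded_param"),
):
    pass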

View File

@ -424,9 +424,9 @@ class FSDPParamGroup:
if all_reduce_pg is None and self._all_reduce_hook_stream is not None: if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
# this means the native HSDP is not enabled, # this means the native HSDP is not enabled,
# but user may want to have a custom HSDP setup # but user may want to have a custom HSDP setup
assert ( assert self._all_reduce_hook is not None, (
self._all_reduce_hook is not None "all reduce hook stream is specified but hook itself is missing."
), "all reduce hook stream is specified but hook itself is missing." )
all_reduce_stream = self._all_reduce_hook_stream all_reduce_stream = self._all_reduce_hook_stream
else: else:
all_reduce_stream = self.comm_ctx.all_reduce_stream all_reduce_stream = self.comm_ctx.all_reduce_stream
@ -513,9 +513,10 @@ class FSDPParamGroup:
else: else:
raise ValueError(f"Unknown pass type: {pass_type}") raise ValueError(f"Unknown pass type: {pass_type}")
target_fqn = target_fsdp_param_group._module_fqn target_fqn = target_fsdp_param_group._module_fqn
with record_function( with (
f"FSDP::{pass_type}_prefetch for {target_fqn}" record_function(f"FSDP::{pass_type}_prefetch for {target_fqn}"),
), target_fsdp_param_group.use_training_state(training_state): target_fsdp_param_group.use_training_state(training_state),
):
async_op = target_fsdp_param_group.unshard_async_op async_op = target_fsdp_param_group.unshard_async_op
target_fsdp_param_group.unshard(async_op) target_fsdp_param_group.unshard(async_op)
@ -592,9 +593,9 @@ class FSDPParamGroup:
def _register_state_dict_hooks(self) -> None: def _register_state_dict_hooks(self) -> None:
num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle) num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle) num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
assert ( assert num_pre_save_hooks == num_pre_load_hooks, (
num_pre_save_hooks == num_pre_load_hooks f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}" )
if num_pre_save_hooks > 0: if num_pre_save_hooks > 0:
return # already registered return # already registered
modules_with_fsdp_params: set[nn.Module] = { modules_with_fsdp_params: set[nn.Module] = {
@ -605,12 +606,12 @@ class FSDPParamGroup:
self._to_sharded() self._to_sharded()
for module in modules_with_fsdp_params: for module in modules_with_fsdp_params:
self._module_to_pre_save_state_dict_hook_handle[ self._module_to_pre_save_state_dict_hook_handle[module] = (
module module.register_state_dict_pre_hook(to_sharded_hook)
] = module.register_state_dict_pre_hook(to_sharded_hook) )
self._module_to_pre_load_state_dict_hook_handle[ self._module_to_pre_load_state_dict_hook_handle[module] = (
module module._register_load_state_dict_pre_hook(to_sharded_hook)
] = module._register_load_state_dict_pre_hook(to_sharded_hook) )
# Properties # # Properties #
@property @property
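The state-dict hook hunk above shows how a long subscripted assignment changes shape: the old layout split the subscript target across lines, while the new one keeps the target intact and parenthesizes the assigned value. A sketch with hypothetical stand-ins for the hook registry:

def register_hook(name: str) -> str:
    # Hypothetical stand-in for module.register_state_dict_pre_hook(...).
    return f"handle:{name}"

module_to_hook_handle: dict[str, str] = {}
module_name = "layer1.linear"

# Old layout: the subscripted target was split across lines.
module_to_hook_handle[
    module_name
] = register_hook(module_name)

# New layout: the target stays on one line and the value is parenthesized.
module_to_hook_handle[module_name] = (
    register_hook(module_name)
)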

View File

@ -60,8 +60,7 @@ def fully_shard(
mp_policy: MixedPrecisionPolicy = ..., mp_policy: MixedPrecisionPolicy = ...,
offload_policy: OffloadPolicy = ..., offload_policy: OffloadPolicy = ...,
ignored_params: Optional[set[nn.Parameter]] = ..., ignored_params: Optional[set[nn.Parameter]] = ...,
) -> FSDPModule: ) -> FSDPModule: ...
...
@overload @overload
@ -74,8 +73,7 @@ def fully_shard(
mp_policy: MixedPrecisionPolicy = ..., mp_policy: MixedPrecisionPolicy = ...,
offload_policy: OffloadPolicy = ..., offload_policy: OffloadPolicy = ...,
ignored_params: Optional[set[nn.Parameter]] = ..., ignored_params: Optional[set[nn.Parameter]] = ...,
) -> list[FSDPModule]: ) -> list[FSDPModule]: ...
...
# The decorator adds a state object to `module` that can be accessed via # The decorator adds a state object to `module` that can be accessed via
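The ``fully_shard`` overloads above illustrate the stub-body change: the bare ``...`` that previously sat on its own line is joined onto the signature. A self-contained sketch with a made-up overloaded function:

from typing import Union, overload

# Old layout: the ellipsis body sits on its own indented line.
@overload
def scale(value: int) -> int:
    ...

# New layout: ruff format keeps the `...` on the signature line.
@overload
def scale(value: float) -> float: ...

def scale(value: Union[int, float]) -> Union[int, float]:
    # The overloads above only describe signatures; this is the single implementation.
    return value * 2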

View File

@ -243,9 +243,9 @@ def _init_inter_node_process_group(
if local_rank == my_local_rank: if local_rank == my_local_rank:
inter_node_pg = grp inter_node_pg = grp
assert ( assert inter_node_pg is not None, (
inter_node_pg is not None f"{my_local_rank} expected to assign inter-node pg, but did not"
), f"{my_local_rank} expected to assign inter-node pg, but did not" )
return inter_node_pg return inter_node_pg

View File

@ -145,9 +145,9 @@ def _unflatten_optim_state(
dict will need to map these entries using the proper unflattened dict will need to map these entries using the proper unflattened
parameter IDs. parameter IDs.
""" """
assert ( assert not shard_state or to_save, (
not shard_state or to_save "If ``shard_state`` is True, ``to_save`` has to be True."
), "If ``shard_state`` is True, ``to_save`` has to be True." )
consolidated_state = _communicate_optim_state( consolidated_state = _communicate_optim_state(
fsdp_param_info, fsdp_param_info,
flat_param_state, flat_param_state,
@ -218,9 +218,9 @@ def _communicate_optim_state(
): ):
tensor_state[state_name] = value tensor_state[state_name] = value
continue continue
assert ( assert fsdp_state.compute_device is not None, (
fsdp_state.compute_device is not None "compute_device has not been initialized"
), "compute_device has not been initialized" )
if value.device.type != fsdp_state.compute_device.type: if value.device.type != fsdp_state.compute_device.type:
value = value.to(fsdp_state.compute_device) value = value.to(fsdp_state.compute_device)
# Assume that positive-dimension tensor optimizer state # Assume that positive-dimension tensor optimizer state
@ -394,7 +394,10 @@ def _shard_orig_param_state(
and value.dim() > 0 and value.dim() > 0
and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
): ):
value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator] value = value.flatten()[
intra_param_start_idx : intra_param_end_idx # type: ignore[operator]
+ 1
].clone()
new_optim_state[state_name] = value new_optim_state[state_name] = value
return new_optim_state return new_optim_state
@ -489,9 +492,9 @@ def _flatten_optim_state_dict(
if flat_state: if flat_state:
flat_osd_state[key] = flat_state flat_osd_state[key] = flat_state
elif use_orig_params: elif use_orig_params:
assert ( assert len(fqns) == 1, (
len(fqns) == 1 f"use_orig_params is True but there are multiple FQNs, {fqns}."
), f"use_orig_params is True but there are multiple FQNs, {fqns}." )
if optim is not None: # NamedOptimizer or KeyedOptimizer case. if optim is not None: # NamedOptimizer or KeyedOptimizer case.
state = optim.state.get(param, None) # type: ignore[call-overload] state = optim.state.get(param, None) # type: ignore[call-overload]
if state is not None: if state is not None:
@ -570,14 +573,13 @@ def _flatten_optim_state(
flat_param = handle.flat_param flat_param = handle.flat_param
num_unflat_params = len(unflat_param_names) num_unflat_params = len(unflat_param_names)
assert num_unflat_params > 0, ( assert num_unflat_params > 0, (
"Expects at least one unflattened parameter corresponding to the " "Expects at least one unflattened parameter corresponding to the flat parameter"
"flat parameter"
) )
unflat_param_shapes = flat_param._shapes unflat_param_shapes = flat_param._shapes
num_unflat_param_shapes = len(unflat_param_shapes) num_unflat_param_shapes = len(unflat_param_shapes)
assert ( assert num_unflat_params == num_unflat_param_shapes, (
num_unflat_params == num_unflat_param_shapes f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" )
# Check if these unflattened parameters have any optimizer state # Check if these unflattened parameters have any optimizer state
has_state = [ has_state = [
@ -759,8 +761,7 @@ def _flatten_tensor_optim_state(
flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel) flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined] flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined]
assert flat_tensor.shape == flat_param_shape, ( assert flat_tensor.shape == flat_param_shape, (
f"tensor optim state: {flat_tensor.shape} " f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}"
f"flat parameter: {flat_param_shape}"
) )
return flat_tensor return flat_tensor
@ -1065,9 +1066,9 @@ def _get_param_key_to_param(
""" """
clean_fqn_to_curr_fqn: dict[str, str] = {} clean_fqn_to_curr_fqn: dict[str, str] = {}
if is_named_optimizer: if is_named_optimizer:
assert ( assert param_to_fqns is not None and flat_param_to_fqn is not None, (
param_to_fqns is not None and flat_param_to_fqn is not None "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
), "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None." )
assert model is not None assert model is not None
for key, _ in _named_parameters_with_duplicates(model): for key, _ in _named_parameters_with_duplicates(model):
clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
@ -1150,9 +1151,9 @@ def _check_missing_keys_on_rank(
continue continue
param_key = optim_state_key_to_param_key[r0_optim_state_key] param_key = optim_state_key_to_param_key[r0_optim_state_key]
if isinstance(param_key, int): if isinstance(param_key, int):
assert param_key >= 0 and param_key < len( assert param_key >= 0 and param_key < len(param_key_to_param), (
param_key_to_param "Check the `param_key_to_param` construction"
), "Check the `param_key_to_param` construction" )
# We cannot use FSDPState.compute_device as this API is a global view. # We cannot use FSDPState.compute_device as this API is a global view.
device = _get_pg_default_device(group) device = _get_pg_default_device(group)
num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device) num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)

View File

@ -121,9 +121,9 @@ def _all_gather_dtensor(
""" """
All gather a DTensor in its sharded dimension and return the local tensor. All gather a DTensor in its sharded dimension and return the local tensor.
""" """
assert ( assert root_mesh == tensor.device_mesh, (
root_mesh == tensor.device_mesh "The device mesh of a tensor should be a root mesh."
), "The device mesh of a tensor should be a root mesh." )
placements = list(copy.deepcopy(tensor.placements)) placements = list(copy.deepcopy(tensor.placements))
# FSDP placements: [Shard(0)] -> [Replicate()] # FSDP placements: [Shard(0)] -> [Replicate()]

View File

@ -466,9 +466,9 @@ def _local_pre_load_state_dict_hook(
) )
return return
load_tensor = state_dict[fqn] load_tensor = state_dict[fqn]
assert isinstance( assert isinstance(load_tensor, ShardedTensor), (
load_tensor, ShardedTensor "Tensors in local_state_dict should be ShardedTensor."
), "Tensors in local_state_dict should be ShardedTensor." )
# Convert the ShardedTensor to a Tensor. # Convert the ShardedTensor to a Tensor.
flat_param = _module_handle(fsdp_state, module).flat_param flat_param = _module_handle(fsdp_state, module).flat_param

View File

@ -143,9 +143,9 @@ class _ExecOrderTracer:
named_params = list(module.named_parameters()) named_params = list(module.named_parameters())
curr_module = exec_info.curr_module curr_module = exec_info.curr_module
if named_params: if named_params:
assert ( assert curr_module in exec_info.module_to_param_usage_infos, (
curr_module in exec_info.module_to_param_usage_infos "The current module should have already been processed by a patched `call_module`"
), "The current module should have already been processed by a patched `call_module`" )
exec_info.module_to_param_usage_infos[exec_info.curr_module].append( exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
_ParamUsageInfo(module, named_params) _ParamUsageInfo(module, named_params)
) )

View File

@ -185,9 +185,9 @@ def _unshard_fsdp_state_params(
yield yield
return return
assert ( assert handle._training_state == HandleTrainingState.IDLE, (
handle._training_state == HandleTrainingState.IDLE f"Expects the handle training to be IDLE but got {handle._training_state}"
), f"Expects the handle training to be IDLE but got {handle._training_state}" )
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

View File

@ -306,16 +306,21 @@ class FullStateDictConfig(StateDictConfig):
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>> state = fsdp.state_dict() >>> state = fsdp.state_dict()
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc: >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0: >>> if dist.get_rank() == 0:
>>> # Load checkpoint only on rank 0 to avoid memory redundancy >>> # Load checkpoint only on rank 0 to avoid memory redundancy
>>> state_dict = torch.load("my_checkpoint.pt") >>> state_dict = torch.load("my_checkpoint.pt")
>>> model.load_state_dict(state_dict) >>> model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world. >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) >>> fsdp = FSDP(
... model,
... device_id=torch.cuda.current_device(),
... auto_wrap_policy=...,
... sync_module_states=True,
... )
>>> # After this point, all ranks have FSDP model with loaded checkpoint. >>> # After this point, all ranks have FSDP model with loaded checkpoint.
Attributes: Attributes:

View File

@ -723,9 +723,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
if prev_state_dict_type is None: if prev_state_dict_type is None:
prev_state_dict_type = submodule._state_dict_type prev_state_dict_type = submodule._state_dict_type
else: else:
assert ( assert prev_state_dict_type == submodule._state_dict_type, (
prev_state_dict_type == submodule._state_dict_type "All FSDP modules should have the same state_dict_type."
), "All FSDP modules should have the same state_dict_type." )
if prev_state_dict_config is None: if prev_state_dict_config is None:
prev_state_dict_config = submodule._state_dict_config prev_state_dict_config = submodule._state_dict_config
else: else:
@ -738,7 +738,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
assert isinstance( assert isinstance(
submodule._optim_state_dict_config, submodule._optim_state_dict_config,
type(prev_optim_state_dict_config), type(prev_optim_state_dict_config),
), "All FSDP modules must have the same type of optim_state_dict_config." ), (
"All FSDP modules must have the same type of optim_state_dict_config."
)
submodule._state_dict_type = state_dict_type submodule._state_dict_type = state_dict_type
submodule._state_dict_config = state_dict_config submodule._state_dict_config = state_dict_config
@ -2153,9 +2155,9 @@ def _get_param_to_fqn(
""" """
param_to_param_names = _get_param_to_fqns(model) param_to_param_names = _get_param_to_fqns(model)
for param_names in param_to_param_names.values(): for param_names in param_to_param_names.values():
assert ( assert len(param_names) > 0, (
len(param_names) > 0 "`_get_param_to_fqns()` should not construct empty lists"
), "`_get_param_to_fqns()` should not construct empty lists" )
if len(param_names) > 1: if len(param_names) > 1:
raise RuntimeError( raise RuntimeError(
"Each parameter should only map to one parameter name but got " "Each parameter should only map to one parameter name but got "

View File

@ -112,20 +112,16 @@ class ShardedGradScaler(GradScaler):
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
@overload @overload
def scale(self, outputs: torch.Tensor) -> torch.Tensor: def scale(self, outputs: torch.Tensor) -> torch.Tensor: ...
...
@overload @overload
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ...
...
@overload @overload
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ...
...
@overload @overload
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ...
...
def scale( def scale(
self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]] self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
@ -323,8 +319,10 @@ class ShardedGradScaler(GradScaler):
if isinstance(new_scale, float): if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr] self._scale.fill_(new_scale) # type: ignore[union-attr]
else: else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ reason = (
"new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
torch.FloatTensor with requires_grad=False." torch.FloatTensor with requires_grad=False."
)
assert new_scale.device.type == self._device, reason assert new_scale.device.type == self._device, reason
assert new_scale.numel() == 1, reason assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason assert new_scale.requires_grad is False, reason

View File

@ -61,9 +61,9 @@ def _post_order_apply(
"Non-root modules should have their module name set but got " "Non-root modules should have their module name set but got "
f"an empty module name for {module}" f"an empty module name for {module}"
) )
assert isinstance( assert isinstance(optional_module, nn.Module), (
optional_module, nn.Module f"fn should return None or an nn.Module but got {optional_module}"
), f"fn should return None or an nn.Module but got {optional_module}" )
setattr(parent_module, module_name, optional_module) setattr(parent_module, module_name, optional_module)
_post_order_apply_inner(root_module, "", None) _post_order_apply_inner(root_module, "", None)
@ -575,9 +575,9 @@ class _ConfigAutoWrap:
) )
_ConfigAutoWrap.in_autowrap_context = True _ConfigAutoWrap.in_autowrap_context = True
# Get and save the wrapper cls for the context. # Get and save the wrapper cls for the context.
assert ( assert "wrapper_cls" in kwargs.keys(), (
"wrapper_cls" in kwargs.keys() "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap." )
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"]) _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
del kwargs["wrapper_cls"] del kwargs["wrapper_cls"]
# Save the rest. # Save the rest.

View File

@ -183,8 +183,7 @@ def parse_args(args):
def launch(args): def launch(args):
if args.no_python and not args.use_env: if args.no_python and not args.use_env:
raise ValueError( raise ValueError(
"When using the '--no-python' flag," "When using the '--no-python' flag, you must also set the '--use-env' flag."
" you must also set the '--use-env' flag."
) )
run(args) run(args)

View File

@ -39,7 +39,10 @@ _REMOTE_MODULE_PICKLED_ATTRIBUTES = (
"module_rref", "module_rref",
) )
_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc] _SerializedRemoteModule = collections.namedtuple( # type: ignore[misc]
"_SerializedRemoteModule",
_REMOTE_MODULE_PICKLED_ATTRIBUTES,
)
# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled. # These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES # A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES

View File

@ -26,15 +26,15 @@ sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)
def get_arg_return_types_from_interface(module_interface): def get_arg_return_types_from_interface(module_interface):
assert getattr( assert getattr(module_interface, "__torch_script_interface__", False), (
module_interface, "__torch_script_interface__", False "Expect a TorchScript class interface decorated by @torch.jit.interface."
), "Expect a TorchScript class interface decorated by @torch.jit.interface." )
qualified_name = torch._jit_internal._qualified_name(module_interface) qualified_name = torch._jit_internal._qualified_name(module_interface)
cu = torch.jit._state._python_cu cu = torch.jit._state._python_cu
module_interface_c = cu.get_interface(qualified_name) module_interface_c = cu.get_interface(qualified_name)
assert ( assert "forward" in module_interface_c.getMethodNames(), (
"forward" in module_interface_c.getMethodNames() f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}" )
method_schema = module_interface_c.getMethod("forward") method_schema = module_interface_c.getMethod("forward")
arg_str_list = [] arg_str_list = []

View File

@ -5,6 +5,7 @@ optimizer locally on the workers where the parameters live. The distributed
optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
apply the gradients on each worker. apply the gradients on each worker.
""" """
import warnings import warnings
import torch import torch

View File

@ -44,10 +44,10 @@ def _apply_optimizer_in_backward(
param_1 = next(params_generator) param_1 = next(params_generator)
remainder_params = list(params_generator) remainder_params = list(params_generator)
apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02}) apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02})
apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04}) apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04})
model(...).sum().backward() # after backward, parameters will already model(...).sum().backward() # after backward, parameters will already
# have their registered optimizer(s) applied. # have their registered optimizer(s) applied.
""" """
@ -111,7 +111,7 @@ def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Opt
List[torch.optim.Optimizer]: the in-backward optimizers. List[torch.optim.Optimizer]: the in-backward optimizers.
Example:: Example::
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01}) _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
optims = _get_optimizers_in_backward(model) optims = _get_optimizers_in_backward(model)
""" """
optims: list[torch.optim.Optimizer] = [] optims: list[torch.optim.Optimizer] = []

View File

@ -147,12 +147,10 @@ class _NamedOptimizer(optim.Optimizer):
return self._post_state_dict({"state": ret_state, "param_groups": ret_groups}) return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
@overload @overload
def step(self, closure: None = ...) -> None: def step(self, closure: None = ...) -> None: ...
...
@overload @overload
def step(self, closure: Callable[[], float]) -> float: def step(self, closure: Callable[[], float]) -> float: ...
...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
""" """

View File

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
r"""Zero Redundancy Optimizer.""" r"""Zero Redundancy Optimizer."""
import collections import collections
import copy import copy
import enum import enum
@ -262,9 +263,9 @@ class _OverlapInfo:
meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles`` meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
in preparation for the next iteration. in preparation for the next iteration.
""" """
assert ( assert len(self.broadcast_handles) == self.num_bucket_assignments, (
len(self.broadcast_handles) == self.num_bucket_assignments f"Missing at least one broadcast handle on rank {dist.get_rank()}"
), f"Missing at least one broadcast handle on rank {dist.get_rank()}" )
_ = [x.wait() for x in self.broadcast_handles] _ = [x.wait() for x in self.broadcast_handles]
self.broadcast_handles.clear() self.broadcast_handles.clear()
@ -909,9 +910,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
params_per_rank = overlap_info.params_per_rank params_per_rank = overlap_info.params_per_rank
offsets = overlap_info.offsets offsets = overlap_info.offsets
self._bucket_assignments_per_rank_cache[assigned_rank][ self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = (
bucket_index _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) )
if self.global_rank == assigned_rank: if self.global_rank == assigned_rank:
offsets[bucket_index] = len(params_per_rank[assigned_rank]) offsets[bucket_index] = len(params_per_rank[assigned_rank])
params_per_rank[assigned_rank].extend(bucket_params) params_per_rank[assigned_rank].extend(bucket_params)
@ -927,9 +928,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
mapping bucket indices to :class:`_DDPBucketAssignment` s for each mapping bucket indices to :class:`_DDPBucketAssignment` s for each
rank. rank.
""" """
assert ( assert self._overlap_with_ddp, (
self._overlap_with_ddp "`_bucket_assignments_per_rank` can only be used if `overlap_with_ddp=True`"
), "`_bucket_assignments_per_rank` can only be used if `overlap_with_ddp=True`" )
if len(self._bucket_assignments_per_rank_cache) > 0: if len(self._bucket_assignments_per_rank_cache) > 0:
return self._bucket_assignments_per_rank_cache return self._bucket_assignments_per_rank_cache
@ -1076,9 +1077,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
"Specifying `gradients` should not " "Specifying `gradients` should not "
"be used when `overlap_with_ddp=False`" "be used when `overlap_with_ddp=False`"
) )
assert ( assert closure is None, (
closure is None "`closure` is not supported when using a local functional optimizer"
), "`closure` is not supported when using a local functional optimizer" )
loss = self.optim.step(gradients=gradients) loss = self.optim.step(gradients=gradients)
# Sync any updated attributes in the local optimizer to the exposed # Sync any updated attributes in the local optimizer to the exposed
@ -1221,9 +1222,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
for rank, local_state_dict in enumerate(self._all_state_dicts): for rank, local_state_dict in enumerate(self._all_state_dicts):
local_param_groups = local_state_dict["param_groups"] local_param_groups = local_state_dict["param_groups"]
global_param_groups = self._partition_parameters()[rank] global_param_groups = self._partition_parameters()[rank]
assert len(local_param_groups) == len( assert len(local_param_groups) == len(global_param_groups), (
global_param_groups "Mismatch between number of local and global parameter groups"
), "Mismatch between number of local and global parameter groups" )
for local_param_group, global_param_group in zip( for local_param_group, global_param_group in zip(
local_param_groups, global_param_groups local_param_groups, global_param_groups
@ -1233,9 +1234,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
local_param_indices = local_param_group["params"] local_param_indices = local_param_group["params"]
global_params = global_param_group["params"] global_params = global_param_group["params"]
assert len(local_param_indices) == len( assert len(local_param_indices) == len(global_params), (
global_params "Mismatch between number of local and global parameters in parameter group"
), "Mismatch between number of local and global parameters in parameter group" )
for local_param_index, global_param in zip( for local_param_index, global_param in zip(
local_param_indices, global_params local_param_indices, global_params
): ):
@ -1268,9 +1269,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
dst_param_groups (list[dict]): parameter groups giving the dst_param_groups (list[dict]): parameter groups giving the
attribute settings to set. attribute settings to set.
""" """
assert len(src_param_groups) == len( assert len(src_param_groups) == len(dst_param_groups), (
dst_param_groups "Mismatch between number of source and destination parameter groups"
), "Mismatch between number of source and destination parameter groups" )
for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups): for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
# Sync all attributes except the parameters # Sync all attributes except the parameters
for attr in filter(lambda x: x != "params", src_param_group.keys()): for attr in filter(lambda x: x != "params", src_param_group.keys()):
@ -1479,9 +1480,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
The local optimizer is saved in ``self.optim``. The local optimizer is saved in ``self.optim``.
""" """
assert ( assert self._optim_constructor is not None, (
self._optim_constructor is not None "The local optimizer class has not been set"
), "The local optimizer class has not been set" )
param_groups = self._partition_parameters()[self.rank] param_groups = self._partition_parameters()[self.rank]
# `overlap_with_ddp=True` requires a local functional optimizer # `overlap_with_ddp=True` requires a local functional optimizer
@ -1508,7 +1509,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
"error due to an empty parameter list", "error due to an empty parameter list",
self._optim_constructor, self._optim_constructor,
) )
self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef] self.optim: Any = self._optim_constructor(
params, **self._optim_defaults
) # type: ignore[no-redef]
# Log information about the DDP and ZeRO bucketing # Log information about the DDP and ZeRO bucketing
if dist.get_debug_level() != dist.DebugLevel.OFF: if dist.get_debug_level() != dist.DebugLevel.OFF:
@ -1531,7 +1534,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
else: else:
# NOTE: Passing `param_groups` into the local optimizer constructor # NOTE: Passing `param_groups` into the local optimizer constructor
# bypasses the empty parameter list check # bypasses the empty parameter list check
self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef] self.optim: Optimizer = self._optim_constructor(
param_groups, **self._optim_defaults
) # type: ignore[no-redef]
# TODO: Manually add `self.param_groups` if using a functional # TODO: Manually add `self.param_groups` if using a functional
# optimizer; remove this if/when the functional optimizers support # optimizer; remove this if/when the functional optimizers support

View File

@ -123,12 +123,11 @@ def _insert_stage_symbolic_backward(
# getitem calls. If we have a target other than getitem in this # getitem calls. If we have a target other than getitem in this
# (forward-only) code, there is a bug. # (forward-only) code, there is a bug.
assert node.target == operator.getitem, ( assert node.target == operator.getitem, (
"Found non-getitem call in forward pass. " "Found non-getitem call in forward pass. Please report a bug to PiPPy"
"Please report a bug to PiPPy" )
assert len(node.args) == 2, (
"Found malformed getitem call. Please report a bug to PiPPy"
) )
assert (
len(node.args) == 2
), "Found malformed getitem call. Please report a bug to PiPPy"
indexed_value, node_idx = tuple(node.args) indexed_value, node_idx = tuple(node.args)
# indexed_value is a collection that we are indexing into. It could # indexed_value is a collection that we are indexing into. It could
@ -249,8 +248,8 @@ class LossWrapper(torch.nn.Module):
targets value into the loss function, and get and return the loss value, which will targets value into the loss function, and get and return the loss value, which will
be backpropagated by PiPPy. The above class would then be instantiated like:: be backpropagated by PiPPy. The above class would then be instantiated like::
model = ... # instantiate the model model = ... # instantiate the model
loss_fn = torch.nn.MSELoss() # for the sake of demonstration loss_fn = torch.nn.MSELoss() # for the sake of demonstration
wrapper = MyModelWrapper(model, loss_fn) wrapper = MyModelWrapper(model, loss_fn)
pipe = Pipe.from_tracing(wrapper, ...) pipe = Pipe.from_tracing(wrapper, ...)
@ -818,9 +817,9 @@ class Pipe(torch.nn.Module):
# Get submodule # Get submodule
callee = root.get_submodule(callee_name) callee = root.get_submodule(callee_name)
assert not hasattr( assert not hasattr(callee, param_fqn), (
callee, param_fqn f"Module {callee_name} already has a parameter named {param_fqn}"
), f"Module {callee_name} already has a parameter named {param_fqn}" )
# Assign the parameter to the submodule # Assign the parameter to the submodule
if is_buffer: if is_buffer:
@ -979,7 +978,7 @@ class Pipe(torch.nn.Module):
else: else:
logger.debug("Pipeline is in inference mode, backward pass not generated") logger.debug("Pipeline is in inference mode, backward pass not generated")
logger.debug("Full pipe model:\n" f"{split}") # noqa: G004 logger.debug(f"Full pipe model:\n{split}") # noqa: G004
return Pipe( return Pipe(
split, split,
@ -1184,7 +1183,7 @@ def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
except AttributeError as e: except AttributeError as e:
raise AttributeError( raise AttributeError(
f"Specified target {qualname} referenced " f"Specified target {qualname} referenced "
f'nonexistent module {".".join(atoms[: i + 1])}' f"nonexistent module {'.'.join(atoms[: i + 1])}"
) from e ) from e
mod_to_wrap = getattr(predecessor_module, atoms[-1]) mod_to_wrap = getattr(predecessor_module, atoms[-1])

View File

@ -306,17 +306,17 @@ def stage_backward(
if isinstance(output_val, torch.Tensor): if isinstance(output_val, torch.Tensor):
if not output_val.requires_grad and output_val.grad_fn is None: if not output_val.requires_grad and output_val.grad_fn is None:
return return
assert isinstance( assert isinstance(grad_val, (torch.Tensor, type(None))), (
grad_val, (torch.Tensor, type(None)) f"Expected Tensor or None gradient but got {type(grad_val)}"
), f"Expected Tensor or None gradient but got {type(grad_val)}" )
stage_output_tensors.append(output_val) stage_output_tensors.append(output_val)
output_grad_tensors.append(grad_val) output_grad_tensors.append(grad_val)
elif isinstance(output_val, (tuple, list)): elif isinstance(output_val, (tuple, list)):
if grad_val is None: if grad_val is None:
return return
assert isinstance( assert isinstance(grad_val, (tuple, list)), (
grad_val, (tuple, list) f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" )
assert len(output_val) == len(grad_val) assert len(output_val) == len(grad_val)
for ov, gv in zip(output_val, grad_val): for ov, gv in zip(output_val, grad_val):
extract_tensors_with_grads( extract_tensors_with_grads(
@ -350,7 +350,8 @@ def stage_backward(
) )
torch.autograd.backward( torch.autograd.backward(
stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type] stage_output_tensors,
grad_tensors=output_grad_tensors, # type: ignore[arg-type]
) )
# Extract gradients wrt the input values # Extract gradients wrt the input values

View File

@ -140,9 +140,9 @@ def _shard_dict_of_args(
real_num_chunks = num_chunks real_num_chunks = num_chunks
first_tensor = True first_tensor = True
assert len(args_dict) == len( assert len(args_dict) == len(args_chunk_spec), (
args_chunk_spec f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" )
for arg_key, arg in args_dict.items(): for arg_key, arg in args_dict.items():
flat, spec = tree_flatten(arg) flat, spec = tree_flatten(arg)

View File

@ -706,7 +706,9 @@ class Schedule1F1B(PipelineScheduleSingle):
recv_work.wait() recv_work.wait()
# Compute # Compute
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] output = self._stage.forward_one_chunk(
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
) # type: ignore[index]
# Clear previous chunk's forward sends (hopefully they have well # Clear previous chunk's forward sends (hopefully they have well
# finished, otherwise, we are heavily communication bound, in which # finished, otherwise, we are heavily communication bound, in which
@ -762,7 +764,9 @@ class Schedule1F1B(PipelineScheduleSingle):
fuse_work.wait() fuse_work.wait()
# Now do the fwd # Now do the fwd
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] output = self._stage.forward_one_chunk(
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
) # type: ignore[index]
# Compute loss # Compute loss
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
@ -992,9 +996,9 @@ def _add_send_recv(
progress = False progress = False
# go in order of ranks even if dict keys aren't ordered # go in order of ranks even if dict keys aren't ordered
for rank in sorted(compute_actions): for rank in sorted(compute_actions):
assert ( assert len(compute_actions[rank]) > 0, (
len(compute_actions[rank]) > 0 f"{rank=}, {len(compute_actions[rank])=}"
), f"{rank=}, {len(compute_actions[rank])=}" )
action = compute_actions[rank][0] action = compute_actions[rank][0]
if not _ready_to_schedule(action, prev_actions[rank]): if not _ready_to_schedule(action, prev_actions[rank]):
@ -1026,9 +1030,9 @@ def _validate_schedule(
num_stages: int, num_stages: int,
num_microbatches: int, num_microbatches: int,
) -> dict[int, int]: ) -> dict[int, int]:
assert ( assert len(actions) == pp_group_size, (
len(actions) == pp_group_size f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}" )
for rank in range(pp_group_size): for rank in range(pp_group_size):
assert rank in actions, f"Schedule is missing actions for rank {rank}" assert rank in actions, f"Schedule is missing actions for rank {rank}"
@ -1048,36 +1052,36 @@ def _validate_schedule(
for action in actions[rank]: for action in actions[rank]:
if action is None: if action is None:
continue continue
assert isinstance( assert isinstance(action, _Action), (
action, _Action f"Got an invalid action: {action}, expected instance of _Action"
), f"Got an invalid action: {action}, expected instance of _Action" )
s_id = action.stage_index s_id = action.stage_index
ctype = action.computation_type ctype = action.computation_type
mb_id = action.microbatch_index mb_id = action.microbatch_index
if ctype == F: if ctype == F:
stage_actions[s_id][F].add(mb_id) stage_actions[s_id][F].add(mb_id)
elif ctype == B: elif ctype == B:
assert ( assert mb_id in stage_actions[s_id][F], (
mb_id in stage_actions[s_id][F] f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward" )
stage_actions[s_id][B].add(mb_id) stage_actions[s_id][B].add(mb_id)
elif ctype == I: elif ctype == I:
assert ( assert mb_id in stage_actions[s_id][F], (
mb_id in stage_actions[s_id][F] f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward" )
stage_actions[s_id][I].add(mb_id) stage_actions[s_id][I].add(mb_id)
elif ctype == W: elif ctype == W:
assert ( assert mb_id in stage_actions[s_id][I], (
mb_id in stage_actions[s_id][I] f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input"
), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input" )
stage_actions[s_id][W].add(mb_id) stage_actions[s_id][W].add(mb_id)
if s_id not in stage_index_to_rank_mapping: if s_id not in stage_index_to_rank_mapping:
stage_index_to_rank_mapping[s_id] = rank stage_index_to_rank_mapping[s_id] = rank
else: else:
existing_rank = stage_index_to_rank_mapping[s_id] existing_rank = stage_index_to_rank_mapping[s_id]
assert ( assert rank == existing_rank, (
rank == existing_rank f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
), f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}" )
for s_id in stage_actions: for s_id in stage_actions:
f_mb = len(stage_actions[s_id][F]) f_mb = len(stage_actions[s_id][F])
@ -1085,14 +1089,14 @@ def _validate_schedule(
i_mb = len(stage_actions[s_id][I]) i_mb = len(stage_actions[s_id][I])
w_mb = len(stage_actions[s_id][W]) w_mb = len(stage_actions[s_id][W])
assert ( assert f_mb == num_microbatches, (
f_mb == num_microbatches f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
), f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}" )
assert ( assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, (
b_mb + (i_mb + w_mb) // 2 == num_microbatches f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
), f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
but got B={b_mb}, I={i_mb}, W={w_mb}" but got B={b_mb}, I={i_mb}, W={w_mb}"
)
return stage_index_to_rank_mapping return stage_index_to_rank_mapping
@ -1289,9 +1293,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
computation_type = action.computation_type computation_type = action.computation_type
mb_index = action.microbatch_index mb_index = action.microbatch_index
stage_index = action.stage_index stage_index = action.stage_index
assert ( assert mb_index is not None, (
mb_index is not None "All currently supported action types require valid microbatch_index"
), "All currently supported action types require valid microbatch_index" )
if computation_type == _ComputationType.FORWARD: if computation_type == _ComputationType.FORWARD:
# perform forward computation # perform forward computation
stage = stage_index_to_stage[stage_index] stage = stage_index_to_stage[stage_index]
@ -1362,9 +1366,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
computation_type = prev_rank_action.computation_type computation_type = prev_rank_action.computation_type
mb_index = prev_rank_action.microbatch_index mb_index = prev_rank_action.microbatch_index
stage_index = prev_rank_action.stage_index stage_index = prev_rank_action.stage_index
assert ( assert mb_index is not None, (
mb_index is not None "All currently supported action types require valid microbatch_index"
), "All currently supported action types require valid microbatch_index" )
# Only handle sends for the forward from a previous rank # Only handle sends for the forward from a previous rank
if computation_type == _ComputationType.FORWARD: if computation_type == _ComputationType.FORWARD:
# If not the last stage, then receive fwd activations # If not the last stage, then receive fwd activations
@ -1393,9 +1397,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
computation_type = next_rank_action.computation_type computation_type = next_rank_action.computation_type
mb_index = next_rank_action.microbatch_index mb_index = next_rank_action.microbatch_index
stage_index = next_rank_action.stage_index stage_index = next_rank_action.stage_index
assert ( assert mb_index is not None, (
mb_index is not None "All currently supported action types require valid microbatch_index"
), "All currently supported action types require valid microbatch_index" )
# Only handle receives for the backwards from a next rank # Only handle receives for the backwards from a next rank
if computation_type in (FORWARD, BACKWARD_WEIGHT): if computation_type in (FORWARD, BACKWARD_WEIGHT):
# Next rank doing forward or weight update has no influence for the current rank backward recv # Next rank doing forward or weight update has no influence for the current rank backward recv
@ -1503,9 +1507,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
"""Dump a CSV representation of the compute + comms schedule into a file with the provided filename.""" """Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
# TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
# that it does not exist if it was created from a compute_comms schedule. # that it does not exist if it was created from a compute_comms schedule.
assert ( assert self.pipeline_order_with_comms is not None, (
self.pipeline_order_with_comms is not None "Must initialize compute_comms schedule before dump_csv"
), "Must initialize compute_comms schedule before dump_csv" )
with open(filename, "w", newline="") as csvfile: with open(filename, "w", newline="") as csvfile:
writer = csv.writer(csvfile) writer = csv.writer(csvfile)
for rank in self.pipeline_order_with_comms: for rank in self.pipeline_order_with_comms:
@ -1541,9 +1545,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
stage.stage_index: stage for stage in self._stages stage.stage_index: stage for stage in self._stages
} }
assert ( assert self.pipeline_order_with_comms is not None, (
self.pipeline_order_with_comms is not None "Must call _load_actions() before calling _step_microbatches()"
), "Must call _load_actions() before calling _step_microbatches()" )
# recv ops indexed by (stage_idx, mb_idx) need to be waited on before use # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
bwd_recv_ops: dict[tuple[int, int], Work] = {} bwd_recv_ops: dict[tuple[int, int], Work] = {}
@ -1562,9 +1566,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
unshard_ops[stage_idx].wait() unshard_ops[stage_idx].wait()
del unshard_ops[stage_idx] del unshard_ops[stage_idx]
unsharded_stages.add(stage_idx) unsharded_stages.add(stage_idx)
assert ( assert stage_idx in unsharded_stages, (
stage_idx in unsharded_stages f"Attempted to compute on sharded {stage_idx=}"
), f"Attempted to compute on sharded {stage_idx=}" )
# count either full_backward or backward_weight together, to determine when to sync DP grads # count either full_backward or backward_weight together, to determine when to sync DP grads
backward_counter: Counter[int] = Counter() backward_counter: Counter[int] = Counter()
@ -1606,7 +1610,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
assert ( assert (
stage_idx, stage_idx,
mb_index, mb_index,
) not in fwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing forward" ) not in fwd_recv_ops, (
"Recv twice for {stage_idx=} {mb_index=} without executing forward"
)
fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
stage.get_fwd_recv_ops(mb_index) stage.get_fwd_recv_ops(mb_index)
) )
@ -1614,7 +1620,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
assert ( assert (
stage_idx, stage_idx,
mb_index, mb_index,
) not in bwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing backward" ) not in bwd_recv_ops, (
"Recv twice for {stage_idx=} {mb_index=} without executing backward"
)
bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
stage.get_bwd_recv_ops(mb_index) stage.get_bwd_recv_ops(mb_index)
) )
@ -1627,12 +1635,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator] unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator]
elif comp_type == RESHARD: elif comp_type == RESHARD:
if stage_uses_fsdp: if stage_uses_fsdp:
assert ( assert stage_idx in unsharded_stages, (
stage_idx in unsharded_stages f"Resharding {stage_idx=} without unsharding"
), f"Resharding {stage_idx=} without unsharding" )
assert ( assert stage_idx not in unshard_ops, (
stage_idx not in unshard_ops f"Resharding {stage_idx=} before finishing unshard"
), f"Resharding {stage_idx=} before finishing unshard" )
stage.submod.reshard() # type: ignore[operator] stage.submod.reshard() # type: ignore[operator]
elif comp_type == FORWARD: elif comp_type == FORWARD:
if stage_uses_fsdp: if stage_uses_fsdp:
@ -1739,7 +1747,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
) )
# TODO(whc) what is the best practice for printing a multiline log? # TODO(whc) what is the best practice for printing a multiline log?
# logger will split it into multiple log lines, but this makes it hard to read (too wide) # logger will split it into multiple log lines, but this makes it hard to read (too wide)
print(_format_pipeline_order(self.pipeline_order_with_comms, error_step_number=time_step)) # type: ignore[arg-type] print(
_format_pipeline_order(
self.pipeline_order_with_comms, # type: ignore[arg-type]
error_step_number=time_step,
)
)
raise e raise e
# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them

View File

@ -243,16 +243,16 @@ class _PipelineStageBase(ABC):
configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
which could show up as hangs, silent corruption, or other errors. which could show up as hangs, silent corruption, or other errors.
""" """
assert ( assert self._outputs_meta is None, (
self._outputs_meta is None "Attempting to reconfigure output_meta, which is not supported"
), "Attempting to reconfigure output_meta, which is not supported" )
self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment]
def get_outputs_meta(self) -> tuple[torch.Tensor, ...]: def get_outputs_meta(self) -> tuple[torch.Tensor, ...]:
"""Get the output metadata (meta tensors) reprensenting the outputs of this stage""" """Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
assert ( assert self._outputs_meta is not None, (
self._outputs_meta is not None "Attempted to get_outputs_meta() without configuring output meta"
), "Attempted to get_outputs_meta() without configuring output meta" )
return self._outputs_meta return self._outputs_meta
def _create_grad_send_info( def _create_grad_send_info(
@ -358,12 +358,12 @@ class _PipelineStageBase(ABC):
prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs) prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs)
for info, tensor in zip(recv_infos, prev_stage_outputs): for info, tensor in zip(recv_infos, prev_stage_outputs):
assert isinstance( assert isinstance(tensor, torch.Tensor), (
tensor, torch.Tensor f"expected tensor values as outputs from prev stage, got {type(tensor)}"
), f"expected tensor values as outputs from prev stage, got {type(tensor)}" )
assert isinstance( assert isinstance(info, _RecvInfo), (
info, _RecvInfo "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo"
), "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" )
# We don't need to do a data copy here, since we can directly pass the activation tensor reference from # We don't need to do a data copy here, since we can directly pass the activation tensor reference from
# one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve # one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve
@ -376,9 +376,9 @@ class _PipelineStageBase(ABC):
""" """
Returns the input grad tensors for this stage, which correspond to the stage inputs during forward. Returns the input grad tensors for this stage, which correspond to the stage inputs during forward.
""" """
assert ( assert self.has_backward, (
self.has_backward "can't steal_bwd_input if this stage doesn't have backward"
), "can't steal_bwd_input if this stage doesn't have backward" )
assert not self.is_first, "can't get bwd output if this stage is first" assert not self.is_first, "can't get bwd output if this stage is first"
self._check_chunk_id(mb_index) self._check_chunk_id(mb_index)
@ -391,22 +391,22 @@ class _PipelineStageBase(ABC):
Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv.
Does not detach or set '_requires_grad'. Does not detach or set '_requires_grad'.
""" """
assert isinstance( assert isinstance(next_stage_bwd_outputs, tuple), (
next_stage_bwd_outputs, tuple f"Expected tuple, got {type(next_stage_bwd_outputs)}"
), f"Expected tuple, got {type(next_stage_bwd_outputs)}" )
assert ( assert self.has_backward, (
self.has_backward "can't set bwd input if this stage doesn't have backward"
), "can't set bwd input if this stage doesn't have backward" )
assert not self.is_last, "can't set bwd input if this stage is last" assert not self.is_last, "can't set bwd input if this stage is last"
recv_infos = self.grad_recv_info[mb_index] recv_infos = self.grad_recv_info[mb_index]
for info, tensor in zip(recv_infos, next_stage_bwd_outputs): for info, tensor in zip(recv_infos, next_stage_bwd_outputs):
assert isinstance( assert isinstance(tensor, torch.Tensor), (
tensor, torch.Tensor f"expected tensor values as outputs from prev stage, got {type(tensor)}"
), f"expected tensor values as outputs from prev stage, got {type(tensor)}" )
assert isinstance( assert isinstance(info, _RecvInfo), (
info, _RecvInfo f"Expected a recv info, got {type(info)}"
), f"Expected a recv info, got {type(info)}" )
info.buffer = tensor info.buffer = tensor
def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
@ -1053,9 +1053,9 @@ class _PipelineStage(_PipelineStageBase):
# If the input is a getitem, we need to go deeper # If the input is a getitem, we need to go deeper
arg_node = arg_node.args[0] arg_node = arg_node.args[0]
assert ( assert arg_node.op == "call_module", (
arg_node.op == "call_module" f"Expecting call_module, got {arg_node.op}"
), f"Expecting call_module, got {arg_node.op}" )
src_stage = self.get_stage_index_of_submod(arg_node.name) src_stage = self.get_stage_index_of_submod(arg_node.name)
# Create a receive buffer for this placeholder # Create a receive buffer for this placeholder
@ -1081,7 +1081,8 @@ class _PipelineStage(_PipelineStageBase):
args_recv_info: list[InputInfo] = [] args_recv_info: list[InputInfo] = []
# Filter out placeholder nodes from `self.submod` (a GraphModule) # Filter out placeholder nodes from `self.submod` (a GraphModule)
placeholders = filter( # type: ignore[var-annotated] placeholders = filter( # type: ignore[var-annotated]
lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr] lambda node: node.op == "placeholder", # type: ignore[arg-type]
self.submod.graph.nodes, # type: ignore[arg-type,union-attr]
) )
# `placeholders` are nodes internal to submod. # `placeholders` are nodes internal to submod.
# `self.node.args` are dependency nodes in the outer graph. # `self.node.args` are dependency nodes in the outer graph.
@ -1300,9 +1301,9 @@ class PipelineStage(_PipelineStageBase):
raise RuntimeError( raise RuntimeError(
"Failed to perform pipeline shape inference- are your inputs on the same device as your module?" "Failed to perform pipeline shape inference- are your inputs on the same device as your module?"
) from e ) from e
assert ( assert output_args is not None, (
output_args is not None "If passing input_args, also pass output_args to override shape inference"
), "If passing input_args, also pass output_args to override shape inference" )
self._configure_outputs_meta( self._configure_outputs_meta(
(output_args,) if isinstance(output_args, torch.Tensor) else output_args (output_args,) if isinstance(output_args, torch.Tensor) else output_args
) )
@ -1346,9 +1347,9 @@ class PipelineStage(_PipelineStageBase):
) )
args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args) args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args)
else: else:
assert ( assert len(args) == 0, (
len(args) == 0 "Can't supply input args for shape inference on non-first stage"
), "Can't supply input args for shape inference on non-first stage" )
objects = [None] objects = [None]
logger.debug( logger.debug(
"Shape inference: stage %s receiving from stage %s", "Shape inference: stage %s receiving from stage %s",

View File

@ -80,9 +80,9 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
world_size = world_size_opt world_size = world_size_opt
if rank != -1 or world_size != -1 or world_size_opt is None: if rank != -1 or world_size != -1 or world_size_opt is None:
query_dict = _query_to_dict(result.query) query_dict = _query_to_dict(result.query)
assert ( assert "rank" not in query_dict and "world_size" not in query_dict, (
"rank" not in query_dict and "world_size" not in query_dict f"The url: {url} has node-specific arguments(rank, world_size) already."
), f"The url: {url} has node-specific arguments(rank, world_size) already." )
if rank != -1: if rank != -1:
query_dict["rank"] = str(rank) query_dict["rank"] = str(rank)
if world_size != -1 or world_size_opt is None: if world_size != -1 or world_size_opt is None:

View File

@ -137,13 +137,13 @@ def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
with _all_gather_dict_lock: with _all_gather_dict_lock:
if not worker_names: if not worker_names:
worker_names = _ALL_WORKER_NAMES worker_names = _ALL_WORKER_NAMES
assert ( assert worker_name in worker_names, (
worker_name in worker_names f"{worker_name} is not expected by leader."
), f"{worker_name} is not expected by leader." )
states = _all_gather_sequence_id_to_states[sequence_id] states = _all_gather_sequence_id_to_states[sequence_id]
assert ( assert worker_name not in states.gathered_objects, (
worker_name not in states.gathered_objects f"{worker_name} reported intent sequence id {sequence_id} twice. "
), f"{worker_name} reported intent sequence id {sequence_id} twice. " )
states.gathered_objects[worker_name] = obj states.gathered_objects[worker_name] = obj
if worker_names == set(states.gathered_objects.keys()): if worker_names == set(states.gathered_objects.keys()):
states.proceed_signal.set() states.proceed_signal.set()
@ -153,9 +153,9 @@ def _broadcast_to_followers(sequence_id, objects_map):
with _all_gather_dict_lock: with _all_gather_dict_lock:
states = _all_gather_sequence_id_to_states[sequence_id] states = _all_gather_sequence_id_to_states[sequence_id]
assert ( assert not states.proceed_signal.is_set(), (
not states.proceed_signal.is_set() f"Termination signal sequence id {sequence_id} got set twice."
), f"Termination signal sequence id {sequence_id} got set twice." )
states.gathered_objects = objects_map states.gathered_objects = objects_map
states.proceed_signal.set() states.proceed_signal.set()
@ -202,9 +202,9 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
function blocks until all workers have received the gathered results. function blocks until all workers have received the gathered results.
""" """
if not worker_names: if not worker_names:
assert ( assert _ALL_WORKER_NAMES is not None, (
_ALL_WORKER_NAMES is not None "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`." )
worker_names = _ALL_WORKER_NAMES worker_names = _ALL_WORKER_NAMES
leader_name = min(worker_names) leader_name = min(worker_names)
@ -930,8 +930,7 @@ def _get_should_profile():
ActiveProfilerType = torch._C._profiler.ActiveProfilerType ActiveProfilerType = torch._C._profiler.ActiveProfilerType
return ( return (
torch.autograd._profiler_enabled() torch.autograd._profiler_enabled()
and torch._C._autograd._profiler_type() and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
== ActiveProfilerType.LEGACY # type: ignore[attr-defined]
) )

View File

@ -23,7 +23,7 @@ def _to_device(device: DeviceType) -> torch.device:
def _to_device_map( def _to_device_map(
device_map: dict[DeviceType, DeviceType] device_map: dict[DeviceType, DeviceType],
) -> dict[torch.device, torch.device]: ) -> dict[torch.device, torch.device]:
full_device_map: dict[torch.device, torch.device] = {} full_device_map: dict[torch.device, torch.device] = {}
reverse_map: dict[torch.device, torch.device] = {} reverse_map: dict[torch.device, torch.device] = {}
@ -127,7 +127,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
>>> options = TensorPipeRpcBackendOptions( >>> options = TensorPipeRpcBackendOptions(
>>> num_worker_threads=8, >>> num_worker_threads=8,
>>> device_maps={"worker1": {0: 1}} >>> device_maps={"worker1": {0: 1}}
>>> # maps worker0's cuda:0 to worker1's cuda:1 >>> # maps worker0's cuda:0 to worker1's cuda:1
>>> ) >>> )
>>> options.set_device_map("worker1", {1: 2}) >>> options.set_device_map("worker1", {1: 2})
>>> # maps worker0's cuda:1 to worker1's cuda:2 >>> # maps worker0's cuda:1 to worker1's cuda:2

View File

@ -63,10 +63,14 @@ class _server_process_global_profile(profile):
>>> import torch.distributed.rpc as rpc >>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> x, y = torch.tensor(1), torch.tensor(2) >>> x, y = torch.tensor(1), torch.tensor(2)
>>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile) >>> outer_profile_rref = rpc.remote(
... dst_worker_name, rpc._server_process_global_profile
... )
>>> outer_profile_rref.rpc_sync().__enter__() >>> outer_profile_rref.rpc_sync().__enter__()
>>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y)) >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
>>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile) >>> inner_profile_rref = rpc.remote(
... dst_worker_name, rpc._server_process_global_profile
... )
>>> inner_profile_rref.rpc_sync().__enter__() >>> inner_profile_rref.rpc_sync().__enter__()
>>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y)) >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
>>> inner_profile_rref.rpc_sync().__exit__(None, None, None) >>> inner_profile_rref.rpc_sync().__exit__(None, None, None)

View File

@ -289,9 +289,9 @@ Important Notices
:: ::
>>> # xdoctest: +SKIP("stub") >>> # xdoctest: +SKIP("stub")
>>> import torch.distributed as dist >>> import torch.distributed as dist
>>> dist.init_process_group(backend="gloo|nccl") >>> dist.init_process_group(backend="gloo|nccl")
3. In your training program, you can either use regular distributed functions 3. In your training program, you can either use regular distributed functions
or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
@ -302,9 +302,9 @@ Important Notices
:: ::
local_rank = int(os.environ["LOCAL_RANK"]) local_rank = int(os.environ["LOCAL_RANK"])
model = torch.nn.parallel.DistributedDataParallel(model, model = torch.nn.parallel.DistributedDataParallel(
device_ids=[local_rank], model, device_ids=[local_rank], output_device=local_rank
output_device=local_rank) )
Please ensure that ``device_ids`` argument is set to be the only GPU device id Please ensure that ``device_ids`` argument is set to be the only GPU device id
that your code will be operating on. This is generally the local rank of the that your code will be operating on. This is generally the local rank of the
@ -331,17 +331,18 @@ utility
:: ::
def main(): def main():
load_checkpoint(checkpoint_path) load_checkpoint(checkpoint_path)
initialize() initialize()
train() train()
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint: def train():
save_checkpoint(checkpoint_path) for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_checkpoint(checkpoint_path)
9. (Recommended) On worker errors, this tool will summarize the details of the error 9. (Recommended) On worker errors, this tool will summarize the details of the error
(e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp) (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
@ -353,17 +354,19 @@ utility
:: ::
from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.elastic.multiprocessing.errors import record
@record
def main():
# do train
pass
if __name__ == "__main__": @record
main() def main():
# do train
pass
if __name__ == "__main__":
main()
""" # noqa: E501 """ # noqa: E501
import os import os
import sys import sys
import uuid import uuid
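
A minimal sketch tying the notices above together (the model and training steps are placeholders, not part of this diff); it assumes the script is launched with torchrun so that `LOCAL_RANK` is set, e.g. `torchrun --standalone --nnodes=1 --nproc-per-node=4 train_sketch.py`:

import os

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record


@record
def main():
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.Linear(8, 8)
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(local_rank), device_ids=[local_rank], output_device=local_rank
        )
    else:
        model = torch.nn.parallel.DistributedDataParallel(model)
    # ... load_checkpoint / training loop / save_checkpoint as outlined above ...
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
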

View File

@ -297,9 +297,9 @@ class DTensor(torch.Tensor):
@staticmethod @staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
assert ( assert flatten_spec is not None, (
flatten_spec is not None "Expecting spec to be not None from `__tensor_flatten__` return value!"
), "Expecting spec to be not None from `__tensor_flatten__` return value!" )
local_tensor = inner_tensors["_local_tensor"] local_tensor = inner_tensors["_local_tensor"]
spec, requires_grad = flatten_spec spec, requires_grad = flatten_spec
unflatten_tensor_meta = TensorMeta( unflatten_tensor_meta = TensorMeta(
@ -694,9 +694,7 @@ def distribute_tensor(
xla_distribute_tensor, xla_distribute_tensor,
) )
return xla_distribute_tensor( return xla_distribute_tensor(tensor, device_mesh, placements) # type:ignore[return-value]
tensor, device_mesh, placements
) # type:ignore[return-value]
except ImportError as e: except ImportError as e:
msg = "To use DTensor API with xla, you must install the torch_xla package!" msg = "To use DTensor API with xla, you must install the torch_xla package!"
raise ImportError(msg) from e raise ImportError(msg) from e
@ -930,7 +928,9 @@ def distribute_module(
FutureWarning, FutureWarning,
stacklevel=2, stacklevel=2,
) )
module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] module.register_forward_pre_hook(
lambda _, inputs: input_fn(inputs, device_mesh) # type: ignore[call-arg]
)
elif num_args == 3: elif num_args == 3:
# input_fn takes in module, inputs, device mesh # input_fn takes in module, inputs, device mesh
module.register_forward_pre_hook( module.register_forward_pre_hook(
@ -990,9 +990,9 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))
# check device_mesh against placements # check device_mesh against placements
assert device_mesh.ndim == len( assert device_mesh.ndim == len(placements), (
placements "mesh dimension does not match the length of placements"
), "mesh dimension does not match the length of placements" )
assert kwargs["layout"] == torch.strided, "layout value not supported!" assert kwargs["layout"] == torch.strided, "layout value not supported!"
torch_stride = torch._prims_common.make_contiguous_strides_for(size) torch_stride = torch._prims_common.make_contiguous_strides_for(size)

View File

@ -75,7 +75,8 @@ def found_inf_reduce_handler(
) -> None: ) -> None:
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
local_tensor_args = pytree.tree_unflatten( local_tensor_args = pytree.tree_unflatten(
cast(list[object], op_info.local_args), op_info.args_tree_spec # type: ignore[arg-type] cast(list[object], op_info.local_args),
op_info.args_tree_spec, # type: ignore[arg-type]
) )
local_tensor_args = cast(tuple[object, ...], local_tensor_args) local_tensor_args = cast(tuple[object, ...], local_tensor_args)
op_call(*local_tensor_args, **op_info.local_kwargs) op_call(*local_tensor_args, **op_info.local_kwargs)
@ -200,8 +201,9 @@ class OpDispatcher:
# did not already construct one # did not already construct one
random._rng_tracker = random.OffsetBasedRNGTracker(mesh) random._rng_tracker = random.OffsetBasedRNGTracker(mesh)
first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( first_arg, first_local_arg = (
torch.Tensor, local_tensor_args[0] cast(dtensor.DTensor, args[0]),
cast(torch.Tensor, local_tensor_args[0]),
) )
rng_context = ( rng_context = (
random._rng_tracker._distribute_region(first_arg._spec) random._rng_tracker._distribute_region(first_arg._spec)
@ -422,18 +424,18 @@ class OpDispatcher:
def wrap(res: object, spec: OutputSpecType) -> object: def wrap(res: object, spec: OutputSpecType) -> object:
if isinstance(res, torch.Tensor): if isinstance(res, torch.Tensor):
if spec is not None: if spec is not None:
assert isinstance( assert isinstance(spec, DTensorSpec), (
spec, DTensorSpec f"output spec does not match with output! Expected DTensorSpec, got {spec}."
), f"output spec does not match with output! Expected DTensorSpec, got {spec}." )
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
else: else:
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
assert res.ndim == 0, "output tensor should be scalar!" assert res.ndim == 0, "output tensor should be scalar!"
return res return res
elif isinstance(res, (list, tuple)): elif isinstance(res, (list, tuple)):
assert spec is not None and isinstance( assert spec is not None and isinstance(spec, (list, tuple)), (
spec, (list, tuple) f"output spec does not match with output! Expected list/tuple, got {spec}."
), f"output spec does not match with output! Expected list/tuple, got {spec}." )
res_list = [] res_list = []
for e, s in zip(res, spec): for e, s in zip(res, spec):
res_list.append(OpDispatcher.wrap(e, s)) res_list.append(OpDispatcher.wrap(e, s))

View File

@ -152,9 +152,9 @@ class OpStrategy(StrategyType):
if isinstance(output_spec, DTensorSpec): if isinstance(output_spec, DTensorSpec):
return output_spec.mesh.shape return output_spec.mesh.shape
else: else:
assert isinstance( assert isinstance(output_spec, tuple), (
output_spec, tuple "found no DTensorSpec in the OpStrategy!"
), "found no DTensorSpec in the OpStrategy!" )
assert output_spec[0] is not None assert output_spec[0] is not None
return output_spec[0].mesh.shape return output_spec[0].mesh.shape

View File

@ -63,9 +63,9 @@ class EinsumDims:
if is_batch_dim: if is_batch_dim:
batch_dims.append(dim_char) batch_dims.append(dim_char)
else: else:
assert ( assert len(input_dims) == 2, (
len(input_dims) == 2 "free dimension only supported for two inputs!"
), "free dimension only supported for two inputs!" )
lhs, rhs = input_dims lhs, rhs = input_dims
if dim_char in lhs: if dim_char in lhs:
lhs_out_only_dims.append(dim_char) lhs_out_only_dims.append(dim_char)

View File

@ -89,9 +89,9 @@ class _MaskPartial(Partial):
# override parent logic to perform partial mask for embedding # override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim) num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim # get local shard size and offset on the embedding_dim
assert ( assert self.offset_shape is not None, (
self.offset_shape is not None "offset_shape needs to be set for _MaskPartial"
), "offset_shape needs to be set for _MaskPartial" )
local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim( local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
self.offset_shape[self.offset_dim], self.offset_shape[self.offset_dim],
num_chunks, num_chunks,

View File

@ -994,9 +994,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
) )
output_specs_list.append(weight_out_spec if output_mask[1] else None) output_specs_list.append(weight_out_spec if output_mask[1] else None)
else: else:
assert ( assert output_mask[1] is False, (
output_mask[1] is False "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
), "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." )
output_specs_list.append(None) output_specs_list.append(None)
# arg: bias # arg: bias
@ -1020,9 +1020,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
) )
output_specs_list.append(bias_out_spec if output_mask[2] else None) output_specs_list.append(bias_out_spec if output_mask[2] else None)
else: else:
assert ( assert output_mask[2] is False, (
output_mask[2] is False "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
), "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." )
output_specs_list.append(None) output_specs_list.append(None)
out_tuple_strategy.strategies.append( out_tuple_strategy.strategies.append(

View File

@ -155,9 +155,9 @@ def _scaled_mm_like_strategy(
assert isinstance(scale_mat2_strategy, OpStrategy) assert isinstance(scale_mat2_strategy, OpStrategy)
# TODO: add support for these later # TODO: add support for these later
assert bias_strategy is None, "_scaled_mm on DTensors doesn't support bias" assert bias_strategy is None, "_scaled_mm on DTensors doesn't support bias"
assert ( assert scale_result_strategy is None, (
scale_result_strategy is None "_scaled_mm on DTensors doesn't support scale_result"
), "_scaled_mm on DTensors doesn't support scale_result" )
# generate all possible strategies for mm # generate all possible strategies for mm
mm_strategy = gen_einsum_strategies(mm_equation, mesh) mm_strategy = gen_einsum_strategies(mm_equation, mesh)
# filter out invalid strategies and associate costs # filter out invalid strategies and associate costs

View File

@ -445,9 +445,9 @@ def pointwise_strategy(
followed_strategy = op_schema.args_schema[max_shards_strategy_index] followed_strategy = op_schema.args_schema[max_shards_strategy_index]
assert isinstance( assert isinstance(followed_strategy, OpStrategy), (
followed_strategy, OpStrategy f"no strategy to follow for {op_schema}!"
), f"no strategy to follow for {op_schema}!" )
return common_pointwise_strategy( return common_pointwise_strategy(
mesh, op_schema.args_schema, followed_strategy, linearity mesh, op_schema.args_schema, followed_strategy, linearity
) )

View File

@ -254,9 +254,9 @@ def dim_movedim(
def dim_repeat(ndim: int, sizes: Shape) -> DimMap: def dim_repeat(ndim: int, sizes: Shape) -> DimMap:
sizes = normalize_sizes(sizes) sizes = normalize_sizes(sizes)
assert ( assert len(sizes) >= ndim, (
len(sizes) >= ndim f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." )
pad = len(sizes) - ndim pad = len(sizes) - ndim
return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple(
Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:])
@ -275,9 +275,9 @@ def infer_size(total_size: int, sizes: Shape) -> Shape:
if infers: if infers:
size = -size size = -size
missing_size = total_size // size missing_size = total_size // size
assert ( assert total_size % size == 0, (
total_size % size == 0 f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
), f"size inferred for -1 is not integral {sizes} should have {total_size} elements." )
return tuple(s if s != -1 else missing_size for s in sizes) return tuple(s if s != -1 else missing_size for s in sizes)
assert size == total_size, f"sizes do not match {total_size} vs {size}" assert size == total_size, f"sizes do not match {total_size} vs {size}"
return sizes return sizes
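
For context on what the assert above guards, a self-contained sketch of the -1 size inference (a hypothetical standalone re-implementation for illustration, not the module's actual helper):

def infer_size_sketch(total_size, sizes):
    # Replace a single -1 entry with total_size divided by the product of the
    # known sizes, asserting divisibility the same way the hunk above does.
    known = 1
    for s in sizes:
        if s != -1:
            known *= s
    if -1 in sizes:
        assert total_size % known == 0, (
            f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
        )
        return tuple(s if s != -1 else total_size // known for s in sizes)
    assert known == total_size, f"sizes do not match {total_size} vs {known}"
    return tuple(sizes)


assert infer_size_sketch(12, (-1, 4)) == (3, 4)
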
@ -538,9 +538,9 @@ def propagate_shape_and_sharding(
for size, shard in zip(mesh_sizes, input_src_placements): for size, shard in zip(mesh_sizes, input_src_placements):
if isinstance(shard, Shard) and shard.dim == in_dim: if isinstance(shard, Shard) and shard.dim == in_dim:
submesh_size *= size submesh_size *= size
assert ( assert out_size % submesh_size == 0, (
out_size % submesh_size == 0 f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." )
# we will only shard our first component of the split # we will only shard our first component of the split
return in_dim if cmd.split_id == 0 else None return in_dim if cmd.split_id == 0 else None

View File

@ -45,7 +45,7 @@ def register_prop_rule(
# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated. # pyre-fixme[2]: Parameter must be annotated.
def wrapper( def wrapper(
impl: Callable[[OpSchema], OutputSharding] impl: Callable[[OpSchema], OutputSharding],
) -> Callable[[OpSchema], OutputSharding]: ) -> Callable[[OpSchema], OutputSharding]:
overloads = op if isinstance(op, list) else [op] overloads = op if isinstance(op, list) else [op]
for overload in overloads: for overload in overloads:
@ -102,7 +102,7 @@ def register_op_strategy(
def as_list( def as_list(
x: Union[list[object], object] x: Union[list[object], object],
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type] ) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args, # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,

View File

@ -231,9 +231,9 @@ def redistribute_local_tensor(
local_tensor, device_mesh, i, my_coordinate[i] local_tensor, device_mesh, i, my_coordinate[i]
) )
else: else:
assert ( assert current.is_shard(), (
current.is_shard() f"Current placement should be shard but found {current}"
), f"Current placement should be shard but found {current}" )
shard_spec = cast(Shard, current) shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim: if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim( new_local_tensor = shard_spec._to_new_shard_dim(

View File

@ -487,9 +487,9 @@ class ShardingPropagator:
strategy_costs: list[float] = [] strategy_costs: list[float] = []
for strtg in strategy.strategies: for strtg in strategy.strategies:
assert ( assert strtg.redistribute_cost is not None, (
strtg.redistribute_cost is not None "must set redistribute cost each strategy!"
), "must set redistribute cost each strategy!" )
redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost)) redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
strategy_costs.append(redistribute_cost) strategy_costs.append(redistribute_cost)

View File

@ -73,9 +73,9 @@ def compute_local_shape_and_global_offset(
if isinstance(placement, Shard): if isinstance(placement, Shard):
shard_dim = placement.dim shard_dim = placement.dim
local_offset = [0] * len(global_shape) local_offset = [0] * len(global_shape)
assert shard_dim < len( assert shard_dim < len(local_shape), (
local_shape f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" )
shard_size, shard_offset = placement._local_shard_size_on_dim( shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim], local_shape[shard_dim],
mesh_dim_size, mesh_dim_size,
@ -141,16 +141,15 @@ def compute_local_shape_and_global_offset(
if isinstance(placement, _StridedShard): if isinstance(placement, _StridedShard):
strided_part_seen[shard_dim] = True strided_part_seen[shard_dim] = True
shard_idx_stride_by_mesh_dim[shard_dim][ shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
idx num_shards_by_tensor_dim[shard_dim]
] = num_shards_by_tensor_dim[shard_dim] // ( // (placement.split_factor * mesh_dim_size)
placement.split_factor * mesh_dim_size
) )
else: else:
num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
shard_idx_stride_by_mesh_dim[shard_dim][ shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
idx num_shards_by_tensor_dim[shard_dim]
] = num_shards_by_tensor_dim[shard_dim] )
shard_idx = [ shard_idx = [
sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)]) sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
@ -205,9 +204,9 @@ def compute_global_tensor_info(
) )
shard_dim = shard_placement.dim shard_dim = shard_placement.dim
assert ( assert shard_dim < tensor.ndim, (
shard_dim < tensor.ndim f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}." )
local_dim_size = tensor_shape[shard_dim] local_dim_size = tensor_shape[shard_dim]
tensor_shape[shard_dim] = local_dim_size * mesh_dim_size tensor_shape[shard_dim] = local_dim_size * mesh_dim_size

View File

@ -283,9 +283,9 @@ class CommDebugMode(TorchDispatchMode):
"module_type" in self.advanced_module_tracker.module_helper_dict[fqn] "module_type" in self.advanced_module_tracker.module_helper_dict[fqn]
and include_module_data and include_module_data
): ):
json_dict[ json_dict["module_type"] = (
"module_type" self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]
] = self.advanced_module_tracker.module_helper_dict[fqn]["module_type"] )
if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]: if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
for ( for (
@ -659,9 +659,9 @@ class CommDebugMode(TorchDispatchMode):
operation_dict["is_bw"] = self.advanced_module_tracker.is_bw operation_dict["is_bw"] = self.advanced_module_tracker.is_bw
# tracks if the operation is part of activation checkpointing # tracks if the operation is part of activation checkpointing
operation_dict[ operation_dict["is_activation_checkpointing"] = (
"is_activation_checkpointing" self.advanced_module_tracker.activation_checkpointing
] = self.advanced_module_tracker.activation_checkpointing )
if any(t == DTensor for t in types): if any(t == DTensor for t in types):
for ele in args: for ele in args:

View File

@ -108,9 +108,9 @@ def _compute_local_shape_and_global_offset(
if isinstance(placement, Shard): if isinstance(placement, Shard):
shard_dim = placement.dim shard_dim = placement.dim
local_offset = [0] * len(global_shape) local_offset = [0] * len(global_shape)
assert shard_dim < len( assert shard_dim < len(local_shape), (
local_shape f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" )
shard_size, shard_offset = placement._local_shard_size_on_dim( shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim], local_shape[shard_dim],
mesh_dim_size, mesh_dim_size,

View File

@ -2,6 +2,7 @@
To run the example, use the following command: To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.py -e MLP_operation_tracing torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.py -e MLP_operation_tracing
""" """
import argparse import argparse
import os import os
from typing import Callable, Union from typing import Callable, Union

View File

@ -6,6 +6,7 @@ with intermediate activations sharded across multiple GPUs via DTensor
To run the example, use the following command: To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
""" """
import os import os
import time import time

View File

@ -3,6 +3,7 @@
The following example demonstrates how to represent torchrec's embedding The following example demonstrates how to represent torchrec's embedding
sharding with the DTensor API. sharding with the DTensor API.
""" """
import argparse import argparse
import os import os
from functools import cached_property from functools import cached_property

View File

@ -253,22 +253,18 @@ class _AttentionOp(Protocol):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
**kwargs: object, **kwargs: object,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]: ...
...
class _RingRotater(ABC): class _RingRotater(ABC):
@abstractmethod @abstractmethod
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ...
...
@abstractmethod @abstractmethod
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ...
...
@abstractmethod @abstractmethod
def next_buffer(self) -> torch.Tensor: def next_buffer(self) -> torch.Tensor: ...
...
class _AllToAllRotater(_RingRotater): class _AllToAllRotater(_RingRotater):
@ -1097,15 +1093,13 @@ class _LoadBalancer(ABC):
@abstractmethod @abstractmethod
def shard( def shard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor: ) -> torch.Tensor: ...
...
@classmethod @classmethod
@abstractmethod @abstractmethod
def unshard( def unshard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor: ) -> torch.Tensor: ...
...
class _SequentialSharder(_LoadBalancer): class _SequentialSharder(_LoadBalancer):
@ -1147,9 +1141,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
def shard( def shard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor: ) -> torch.Tensor:
assert ( assert cls.ROUND_ROBIN_CYCLE == 2, (
cls.ROUND_ROBIN_CYCLE == 2 "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." )
cp_world_size = mesh.size() cp_world_size = mesh.size()
cp_rank = mesh.get_local_rank() cp_rank = mesh.get_local_rank()
assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0 assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0
@ -1163,9 +1157,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
def unshard( def unshard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor: ) -> torch.Tensor:
assert ( assert cls.ROUND_ROBIN_CYCLE == 2, (
cls.ROUND_ROBIN_CYCLE == 2 "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." )
buffer = buffer.contiguous() buffer = buffer.contiguous()
cp_world_size = mesh.size() cp_world_size = mesh.size()
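
Several hunks in this file collapse method bodies that consist only of `...` (the Protocol and ABC stubs above); a tiny illustrative sketch of that rule, with a hypothetical class name:

from abc import ABC, abstractmethod


class RotaterSketch(ABC):
    # Black (old) kept the Ellipsis body on its own line:
    #     @abstractmethod
    #     def next_buffer(self) -> None:
    #         ...
    # ruff format (new) collapses it onto the def line:
    @abstractmethod
    def next_buffer(self) -> None: ...
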

View File

@ -113,9 +113,15 @@ def local_map(
>>> device_mesh=device_mesh, >>> device_mesh=device_mesh,
>>> ) >>> )
>>> >>>
>>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor >>> W_dt = distribute_tensor(
>>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor ... W, device_mesh, (col_wise)
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors ... ) # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(
... X, device_mesh, (row_wise)
... ) # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(
... device_mesh, W_dt, X_dt
... ) # apply local_mm_allreduce_forward to DTensors
.. note:: This API is currently experimental and subject to change .. note:: This API is currently experimental and subject to change
""" """
@ -151,9 +157,9 @@ def local_map(
) )
if in_placements is not None: if in_placements is not None:
spec = in_placements[idx] spec = in_placements[idx]
assert ( assert spec is not None, (
spec is not None f"DTensor input {arg} expects placements but received {spec}!"
), f"DTensor input {arg} expects placements but received {spec}!" )
if not isinstance(spec, tuple): if not isinstance(spec, tuple):
spec = tuple(spec) spec = tuple(spec)
@ -208,17 +214,17 @@ def local_map(
) )
for out, spec in zip(flat_out, out_placements_tuple): for out, spec in zip(flat_out, out_placements_tuple):
if isinstance(out, torch.Tensor): if isinstance(out, torch.Tensor):
assert not isinstance( assert not isinstance(out, DTensor), (
out, DTensor f"torch.Tensor output expected but received {type(out)}: {out}"
), f"torch.Tensor output expected but received {type(out)}: {out}" )
flat_dist_out.append( flat_dist_out.append(
DTensor.from_local(out, device_mesh, spec, run_check=False) DTensor.from_local(out, device_mesh, spec, run_check=False)
) )
else: else:
assert ( assert spec is None, (
spec is None f"Non-tensor output {out} expects None placements but received {spec}!"
), f"Non-tensor output {out} expects None placements but received {spec}!" )
flat_dist_out.append(out) flat_dist_out.append(out)

View File

@@ -188,9 +188,14 @@ def _mark_sharding(
     """
     Mark the sharding strategy for each node in the graph module.
     """
-    placement_strategies: dict[
-        Node, PlacementStrategy
-    ] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)
+    placement_strategies: dict[Node, PlacementStrategy] = (
+        _mark_tensor_parallel_shardings(
+            gm,
+            graph_signature,
+            mesh,
+            parameter_placements,
+        )
+    )
     for node in gm.graph.nodes:
         if node.op == "placeholder":
@@ -202,9 +207,9 @@ def _mark_sharding(
         elif node.op == "call_function":
             if node.target == operator.getitem:
                 input_nodes = node.all_input_nodes
-                assert (
-                    len(input_nodes) == 1
-                ), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
+                assert len(input_nodes) == 1, (
+                    f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
+                )
                 arg_strategy = placement_strategies[input_nodes[0]]
                 placement_strategies[node] = _create_placement_strategy(
                     node,
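The first hunk above illustrates how ruff format treats an annotated assignment that exceeds the line limit: rather than splitting the subscripted type annotation as Black did, it keeps the annotation on one line and wraps the right-hand side in parentheses, expanding the call arguments one per line. A standalone sketch under invented names, shaped after the hunk above rather than taken from it:

def merge_options(first: int, second: int, third: int, fourth: int) -> dict[str, int]:
    # Hypothetical helper standing in for the sharding-marking call above.
    return {"first": first, "second": second, "third": third, "fourth": fourth}


# Black-era layout split the annotation itself:
#     merged_options: dict[
#         str, int
#     ] = merge_options(first_option, second_option, third_option, fourth_option)
# ruff format keeps the annotation intact and parenthesizes the value instead:
first_option, second_option, third_option, fourth_option = 1, 2, 3, 4
merged_options: dict[str, int] = (
    merge_options(
        first_option,
        second_option,
        third_option,
        fourth_option,
    )
)
print(merged_options)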


@@ -328,7 +328,9 @@ class DTensorExtensions(FSDPExtensions):
         self.device_handle = device_handle
         # we have to use the dynamo disable this way to disable dynamo as the decorater way would
         # trigger build failure with torch deploy...
-        self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform)  # type: ignore[method-assign]
+        self.post_unflatten_transform = torch._dynamo.disable(  # type: ignore[method-assign]
+            self.post_unflatten_transform
+        )

     def pre_flatten_transform(
         self,
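When a call that used to fit on one line is split, the trailing `# type: ignore[...]` comment stays on the first physical line of the statement, which is typically where type checkers report the suppressed error. A small sketch with invented names; `wrap_callable` below is a stand-in for a decorator-style helper such as the dynamo-disable wrapper, not the real API:

from typing import Any, Callable


def wrap_callable(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Stand-in wrapper; simply returns the callable unchanged.
    return fn


class Extensions:
    def transform(self, value: int) -> int:
        return value + 1

    def install(self) -> None:
        # The ignore comment rides on the opening line of the now-split call.
        self.transform = wrap_callable(  # type: ignore[method-assign]
            self.transform
        )


ext = Extensions()
ext.install()
print(ext.transform(1))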


@@ -64,9 +64,7 @@ def input_reshard(
     return module


-def _pack_hook_tp(
-    mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor
-) -> Any:  # noqa: D401
+def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any:  # noqa: D401
     """Hook function called after FWD to shard input."""
     if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
         return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
@@ -84,9 +82,7 @@ def _pack_hook_tp(
     return x


-def _unpack_hook_tp(
-    mesh: DeviceMesh, input_reshard_dim: int, x: Any
-) -> torch.Tensor:  # noqa: D401
+def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor:  # noqa: D401
     """Hook function called before activation recomputing in BWD to restore input."""
     if (
         isinstance(x, DTensor)
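The two hunks above go in the opposite direction: signatures that now fit within the configured line length are folded back onto a single line, and the trailing `# noqa: D401` directive moves onto the def line with them. A sketch with made-up hook names:

from typing import Any


# Previously spread over three lines:
#     def pack_hook(
#         reshard_dim: int, value: Any
#     ) -> Any:  # noqa: D401
# After the migration the signature fits on one line and carries the directive comment:
def pack_hook(reshard_dim: int, value: Any) -> Any:  # noqa: D401
    """Hypothetical hook standing in for the pack/unpack hooks above."""
    return value


print(pack_hook(0, "payload"))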


@@ -38,8 +38,7 @@ class ParallelStyle(ABC):
     src_data_rank: Optional[int] = 0

     @abstractmethod
-    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        ...
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ...


 class ColwiseParallel(ParallelStyle):
@@ -467,19 +466,21 @@ class PrepareModuleInput(ParallelStyle):
         )
         self.use_local_output = use_local_output
         if self.input_layouts is not None:
-            assert (
-                self.desired_input_layouts is not None
-            ), "desired module inputs should not be None!"
-            assert len(self.input_layouts) == len(
-                self.desired_input_layouts
-            ), "input_layouts and desired_input_layouts should have same length!"
+            assert self.desired_input_layouts is not None, (
+                "desired module inputs should not be None!"
+            )
+            assert len(self.input_layouts) == len(self.desired_input_layouts), (
+                "input_layouts and desired_input_layouts should have same length!"
+            )
         self.with_kwargs = input_kwarg_layouts is not None
         self.input_kwarg_layouts = input_kwarg_layouts or {}
         self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
         if self.with_kwargs:
             assert len(self.input_kwarg_layouts) == len(
                 self.desired_input_kwarg_layouts
-            ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
+            ), (
+                "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
+            )

     def _prepare_input_arg(
         self,
@@ -494,9 +495,9 @@ class PrepareModuleInput(ParallelStyle):
                 # assert inp.placements[0] == input_layout
                 dt_inp = input
             else:
-                assert isinstance(
-                    input, torch.Tensor
-                ), "expecting input to be a torch.Tensor!"
+                assert isinstance(input, torch.Tensor), (
+                    "expecting input to be a torch.Tensor!"
+                )
                 dt_inp = DTensor.from_local(
                     input, mesh, (input_layout,), run_check=False
                 )
@@ -517,9 +518,9 @@ class PrepareModuleInput(ParallelStyle):
         if len(inputs) != len(self.input_layouts):
             raise ValueError("module inputs and input_layouts should have same length!")
-        assert (
-            self.desired_input_layouts is not None
-        ), "desired module inputs should not be None!"
+        assert self.desired_input_layouts is not None, (
+            "desired module inputs should not be None!"
+        )
         for inp, input_layout, desired_layout in zip(
             inputs, self.input_layouts, self.desired_input_layouts
         ):
@@ -551,7 +552,9 @@ class PrepareModuleInput(ParallelStyle):
                 with_kwargs=True,
             )  # type: ignore[misc]
         else:
-            module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh))  # type: ignore[misc, call-arg]
+            module.register_forward_pre_hook(
+                lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
+            )  # type: ignore[misc, call-arg]
         return module
@@ -611,9 +614,9 @@ class PrepareModuleOutput(ParallelStyle):
             else desired_output_layouts
         )
         self.use_local_output = use_local_output
-        assert len(self.output_layouts) == len(
-            self.desired_output_layouts
-        ), "output_layouts and desired_output_layouts should have same length!"
+        assert len(self.output_layouts) == len(self.desired_output_layouts), (
+            "output_layouts and desired_output_layouts should have same length!"
+        )

     def _prepare_out_fn(self, outputs, device_mesh):
         prepared_outputs = []
@@ -649,5 +652,7 @@ class PrepareModuleOutput(ParallelStyle):
         return tuple(prepared_outputs)

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh))  # type: ignore[misc, call-arg]
+        module.register_forward_hook(
+            lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
+        )  # type: ignore[misc, call-arg]
         return module
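Two further layout rules show up in the hunks above: an abstract method whose body is only `...` now keeps the ellipsis on the signature line, and a registration call whose lambda argument pushes it over the limit is split so the lambda gets its own line, with the `# type: ignore` comment staying on the closing parenthesis. A compact sketch with placeholder classes, not drawn from the real ParallelStyle API:

from abc import ABC, abstractmethod
from typing import Any, Callable


class Style(ABC):
    @abstractmethod
    def apply(self, module: Any) -> Any: ...  # ellipsis body stays on the signature line


class LoggingStyle(Style):
    def __init__(self) -> None:
        self.hooks: list[Callable[[Any], None]] = []

    def apply(self, module: Any) -> Any:
        # The lambda argument sits on its own line once the call no longer fits.
        self.hooks.append(
            lambda inputs: print("preparing", module, inputs)
        )
        return module


style = LoggingStyle()
style.apply("linear")
style.hooks[0]((1, 2))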


@@ -83,9 +83,9 @@ class Shard(Placement):
         few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
         This is because collectives usually require equal size tensor inputs
         """
-        assert (
-            self.dim <= tensor.ndim
-        ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+        assert self.dim <= tensor.ndim, (
+            f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+        )

         # chunk tensor over dimension `dim` into n slices
         tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
@@ -468,9 +468,9 @@ class _StridedShard(Shard):
         """
         TODO: currently _StridedShard does not support padding
         """
-        assert (
-            self.dim <= tensor.ndim
-        ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+        assert self.dim <= tensor.ndim, (
+            f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+        )

         total_split = num_chunks * self.split_factor
         assert tensor.size(self.dim) % total_split == 0, (
