[BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
Author: Xuehai Pan
Date: 2025-02-28 11:10:58 +08:00
Committed by: PyTorch MergeBot
Parent: 4e160d5fd9
Commit: 995df34b19
143 changed files with 920 additions and 774 deletions
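
Most of the hunks below apply the same mechanical ruff-format transformations: a multi-line assert keeps its condition on one line and wraps the message in parentheses instead of the condition, bodies that consist only of "..." collapse onto the signature line, and a lone parameter wrapped across lines gains a trailing comma. A minimal, hypothetical before/after sketch of these patterns (illustrative names and values, not taken from the diff itself):

from collections.abc import Callable
from typing import Protocol

group_size, numel = 4, 16

# black style: parentheses wrap the condition; the message hangs off the closing paren.
assert (
    numel % group_size == 0
), f"numel {numel} must be a multiple of group_size {group_size}"

# ruff format style: the condition stays inline; the message is parenthesized instead.
assert numel % group_size == 0, (
    f"numel {numel} must be a multiple of group_size {group_size}"
)

# Bodies that are only "..." collapse onto the signature line.
class _Stub(Protocol):
    def state(self, key: str) -> int: ...  # was "-> int:" on one line and "..." on the next

# A lone parameter wrapped across lines gains a magic trailing comma.
def wrap(
    hook: Callable[[int], int],  # trailing comma added by ruff format
) -> Callable[[int], int]:
    return hook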


@ -59,7 +59,6 @@ USE_BLACK_FILELIST = re.compile(
# torch/[a-c]*/**
"torch/[a-c]*/**",
# torch/d*/**
"torch/d*/**",
# torch/[e-n]*/**
"torch/[e-n]*/**",
# torch/optim/**
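
The hunk above removes "torch/d*/**" from USE_BLACK_FILELIST, so paths under torch/distributed/ and torch/distributions/ no longer match any remaining glob and are picked up by ruff format instead of black. A rough, hypothetical sketch of the effect, with fnmatch standing in for the matching that the compiled regular expression actually performs:

from fnmatch import fnmatch

# Simplified stand-in for USE_BLACK_FILELIST after this commit; the real list
# is longer and is built with re.compile, as shown in the hunk above.
use_black_globs = [
    "torch/[a-c]*/**",
    "torch/[e-n]*/**",
]

def still_uses_black(path: str) -> bool:
    # True if the path matches one of the remaining black globs (sketch only).
    return any(fnmatch(path, glob) for glob in use_black_globs)

print(still_uses_black("torch/distributed/distributed_c10d.py"))  # False: now ruff format
print(still_uses_black("torch/autograd/function.py"))             # True: still formatted with black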


@ -36,11 +36,9 @@ _M = TypeVar("_M", nn.Module, list[nn.Module])
class _ContractFn(Protocol, Generic[_P, _T, _TState]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
...
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
def state(self, module: nn.Module) -> _TState:
...
def state(self, module: nn.Module) -> _TState: ...
def contract(
@ -92,7 +90,7 @@ def contract(
# wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
@wraps(state_cls) # type: ignore[arg-type]
def inner(
func: Callable[Concatenate[_M, _P], _M]
func: Callable[Concatenate[_M, _P], _M],
) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
@wraps(func)
def wrapper(
@ -232,9 +230,7 @@ def contract(
return module.__dict__.setdefault( # type: ignore[call-overload]
STATE_KEY,
{}, # TODO(@yhcharles): this is a temporary fix, need a better way
).get(
func
) # type: ignore[call-overload]
).get(func) # type: ignore[call-overload]
wrapper.state = get_state # type: ignore[attr-defined]


@ -274,9 +274,9 @@ def reduce_scatter_tensor(
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)
assert (
self.size(scatter_dim) % group_size == 0
), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
assert self.size(scatter_dim) % group_size == 0, (
f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
)
if scatter_dim != 0:
tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
self = torch.cat(tensor_list)
@ -313,9 +313,9 @@ def reduce_scatter_tensor_autograd(
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)
assert (
self.size(scatter_dim) % group_size == 0
), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
assert self.size(scatter_dim) % group_size == 0, (
f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
)
if scatter_dim != 0:
tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
self = torch.cat(tensor_list)
@ -414,9 +414,9 @@ def reduce_scatter_tensor_coalesced(
assert len(scatter_dim) == len(inputs)
for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
assert (
tensor.size(dim) % group_size == 0
), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
assert tensor.size(dim) % group_size == 0, (
f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
)
if dim != 0:
tensor_list = torch.chunk(tensor, group_size, dim=dim)
inputs[idx] = torch.cat(tensor_list)
@ -574,6 +574,7 @@ class AsyncCollectiveTensor(torch.Tensor):
tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
return _maybe_wrap_tensor(tensor)
"""
elem: torch.Tensor
completed: bool
@ -726,9 +727,9 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int
group_size = len(rankset)
tag = tag or c10d._get_group_tag(group)
elif isinstance(group, DeviceMesh):
assert (
group.ndim == 1
), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
assert group.ndim == 1, (
"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
)
# TODO: it should run collective in the whole mesh instead of dim 0
tag, rankset, _ = group._dim_group_infos[0]
group_size = len(rankset)
@ -763,9 +764,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
elif isinstance(group, str):
return group
elif isinstance(group, DeviceMesh):
assert (
group.ndim == 1
), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
assert group.ndim == 1, (
"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
)
return group._dim_group_infos[0][2]
elif isinstance(group, tuple):
if (
@ -837,11 +838,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
return y
@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
torch.ops.c10d_functional.wait_tensor(y)
return y * y
x = torch.ones(1280, 1280, device="cuda") + self.rank
# the context manager ensures that `wait_tensor(y)` will wait on the correct work object
with allow_inflight_collective_as_graph_input_ctx():
@ -1057,9 +1060,9 @@ def all_gather_tensor_inplace(
tag: str = "",
gather_dim: int = 0,
):
assert (
not async_op
), "Can't remap async version of inplace op to functional collective"
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
group = group or dist.group.WORLD
assert group is not None
@ -1076,9 +1079,9 @@ def reduce_scatter_tensor_inplace(
scatter_dim: int = 0,
tag: str = "",
):
assert (
not async_op
), "Can't remap async version of inplace op to functional collective"
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
group = group or dist.group.WORLD
assert group is not None
@ -1105,9 +1108,9 @@ def all_reduce_inplace(
async_op: bool = False,
tag: str = "",
):
assert (
not async_op
), "Can't remap async version of inplace op to functional collective"
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
group = group or dist.group.WORLD
assert group is not None
@ -1124,9 +1127,9 @@ def all_to_all_inplace(
async_op=False,
tag: str = "",
):
assert (
not async_op
), "Can't remap async version of inplace op to functional collective"
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
group = group or dist.group.WORLD
assert group is not None
@ -1149,12 +1152,12 @@ def all_gather_inplace(
async_op=False,
tag: str = "",
):
assert (
not async_op
), "Can't remap async version of inplace op to functional collective"
assert all(
t.size(0) == tensor.size(0) for t in tensor_list
), "Remapping variable size all_gather is not yet supported"
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
assert all(t.size(0) == tensor.size(0) for t in tensor_list), (
"Remapping variable size all_gather is not yet supported"
)
group = group or dist.group.WORLD
assert group is not None


@ -592,7 +592,9 @@ class ShardedTensor(ShardedTensorBase):
assert (
isinstance(device, torch.device)
and device.index == torch.cuda.current_device()
), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""
), (
"""Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""
)
current_device = torch.device(torch.cuda.current_device())
# returns a copy of ShardedTensor on CUDA current device
@ -831,7 +833,9 @@ class ShardedTensor(ShardedTensorBase):
"rank:1/cuda:1",
],
)
>>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
>>> st = ShardedTensor._init_from_local_tensor(
... local_tensor, sharding_spec, [2, 4]
... )
>>> st
ShardedTensor(
ShardedTensorMetadata(


@ -219,9 +219,7 @@ def reshard_local_shard(
output_tensor_size = list(st_size)
output_tensor_size[current_sharding_dim] = sharded_dim_size
output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
output_tensor_list[
placement.rank()
] = torch.empty( # type: ignore[union-attr, index]
output_tensor_list[placement.rank()] = torch.empty( # type: ignore[union-attr, index]
output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
)
indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type]


@ -16,6 +16,6 @@ with warnings.catch_warnings():
stacklevel=2,
)
sys.modules[
"torch.distributed._sharded_tensor"
] = torch.distributed._shard.sharded_tensor
sys.modules["torch.distributed._sharded_tensor"] = (
torch.distributed._shard.sharded_tensor
)


@ -67,7 +67,7 @@ def _all_gather_sharded_tensor(
class CompanionMismatch(Exception):
...
pass
def _iterate_state_dict(
@ -409,9 +409,9 @@ def _create_cpu_state_dict(
def unpin_memory(t):
succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))
assert (
succ == 0
), f"Unpinning shared memory failed with error-code: {succ}"
assert succ == 0, (
f"Unpinning shared memory failed with error-code: {succ}"
)
weakref.finalize(t, unpin_memory, t)
succ = int(
@ -421,9 +421,9 @@ def _create_cpu_state_dict(
1, # lines up with 'cudaHostRegisterPortable'
)
)
assert (
succ == 0
), f"Pinning shared memory failed with error-code: {succ}"
assert succ == 0, (
f"Pinning shared memory failed with error-code: {succ}"
)
return t
elif pin_memory:
return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()


@ -1525,8 +1525,7 @@ if TYPE_CHECKING:
@overload
def empty(
*size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
) -> torch.Tensor:
...
) -> torch.Tensor: ...
@overload
@ -1535,8 +1534,7 @@ def empty(
*,
dtype: Optional[_dtype] = None,
device: Optional[_device] = None,
) -> torch.Tensor:
...
) -> torch.Tensor: ...
def empty( # type: ignore[misc]


@ -6,6 +6,7 @@ we keep the old import path starts with `_tensor` for
backward compatibility. We will remove this folder once
we resolve all the BC issues.
"""
import sys
from importlib import import_module


@ -153,7 +153,7 @@ class FSDPMemTracker(MemTracker):
loss.backward()
optimizer.step()
fmt.display_snapshot("peak")
fmt.display_modulewise_snapshots(depth = 3, units = "MB")
fmt.display_modulewise_snapshots(depth=3, units="MB")
"""


@ -379,7 +379,7 @@ class MemTracker(TorchDispatchMode):
optimizer.step()
optimizer.zero_grad()
mt.display_snapshot("peak")
mt.display_modulewise_snapshots(depth = 3, units = "MiB")
mt.display_modulewise_snapshots(depth=3, units="MiB")
Known Limitations:
- The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``.


@ -42,6 +42,7 @@ class ModTracker:
def my_linear(m1, m2, bias):
print(f"Current modules: {tracker.parents}")
return torch.mm(m1, m2.t()) + bias
torch.nn.functional.linear = my_linear
mod(torch.rand(2, 2))


@ -255,9 +255,9 @@ class RuntimeEstimator(TorchDispatchMode):
Tuple[Any, float]: A tuple containing the result of the function and
the mean operation time in milliseconds.
"""
assert isinstance(
cls.fake_mode, FakeTensorMode
), "Initialize/Assign FakeTensorMode before using this function"
assert isinstance(cls.fake_mode, FakeTensorMode), (
"Initialize/Assign FakeTensorMode before using this function"
)
mean_op_time = 0.0
if func._overloadpacket not in _VIEW_OPS:
try:
@ -289,9 +289,9 @@ class RuntimeEstimator(TorchDispatchMode):
Tuple[Any, float]: A tuple containing the result of the function and
the mean operation time in milliseconds.
"""
assert (
torch.cuda.is_available()
), "Roofline estimation needs to access CUDA capabilities to make estimations"
assert torch.cuda.is_available(), (
"Roofline estimation needs to access CUDA capabilities to make estimations"
)
def get_num_bytes(t: torch.Tensor) -> int:
"""
@ -324,9 +324,9 @@ class RuntimeEstimator(TorchDispatchMode):
float: The estimated compute time in nanoseconds.
"""
if func_packet in flop_registry:
assert (
len(out_dtypes) == 1
), f"Only support single out dtype got {out_dtypes} for {func_packet}"
assert len(out_dtypes) == 1, (
f"Only support single out dtype got {out_dtypes} for {func_packet}"
)
dtype = out_dtypes.pop()
# This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s
peak_gpu_flops = get_device_tflops(dtype) * 1e15
@ -487,9 +487,9 @@ class RuntimeEstimator(TorchDispatchMode):
def __enter__(self) -> Self:
fake_mode = active_fake_mode()
assert isinstance(
fake_mode, FakeTensorMode
), "No FakeTensorMode found, designed to used under FakeTensorMode"
assert isinstance(fake_mode, FakeTensorMode), (
"No FakeTensorMode found, designed to used under FakeTensorMode"
)
RuntimeEstimator.fake_mode = fake_mode
self.total_runtime = 0.0
self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0))


@ -245,7 +245,7 @@ class SACEstimator(TorchDispatchMode):
with FakeTensorMode():
module = ...
inp = ...
with sac_estimator('operator-level-cost-model'):
with sac_estimator("operator-level-cost-model"):
output = module(inp)
sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True)
"""
@ -442,9 +442,9 @@ class SACEstimator(TorchDispatchMode):
out_storages_cpu.update(_get_untyped_storages(o))
# Check if there's more than 1 CUDA device
assert (
len(cuda_devices) <= 1
), f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
assert len(cuda_devices) <= 1, (
f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
)
# 2. Get the memory consumed by output
nbytes_cuda = sum(
@ -484,9 +484,9 @@ class SACEstimator(TorchDispatchMode):
if acm_stats := self._sac_mod_metadata.get(mod_fqn, None):
acm_stats.sac_metadata.append(acm)
else:
assert (
mod_fqn == "Global"
), f"Module {mod_fqn} not found in AC Mod Stats"
assert mod_fqn == "Global", (
f"Module {mod_fqn} not found in AC Mod Stats"
)
self._sac_metadata.append(acm)
return out
@ -979,9 +979,9 @@ class SACEstimator(TorchDispatchMode):
def __enter__(self) -> Self: # type: ignore[no-untyped-def]
fake_mode = active_fake_mode()
assert isinstance(
fake_mode, FakeTensorMode
), "SAC Estimator should be called in FakeTensorMode"
assert isinstance(fake_mode, FakeTensorMode), (
"SAC Estimator should be called in FakeTensorMode"
)
RuntimeEstimator.fake_mode = fake_mode
self._mod_tracker.register_user_hooks(
pre_fw_hook=self._pre_fw_hook,


@ -38,9 +38,9 @@ def _perform_local_step(
"""
overlap_info = zero._overlap_info
bucket_index = bucket.index()
assert (
len(zero.optim.param_groups) == 1
), "Overlapping DDP with ZeRO only supports a single parameter group"
assert len(zero.optim.param_groups) == 1, (
"Overlapping DDP with ZeRO only supports a single parameter group"
)
# Construct the `gradients` input for the local optimizer step, which
# expects `None` in a list position to indicate that the corresponding
@ -49,9 +49,9 @@ def _perform_local_step(
gradients: list[Optional[torch.Tensor]] = [
_NO_PARAM_UPDATE for _ in range(num_local_optim_params)
]
assert (
bucket_index in overlap_info.offsets
), f"Bucket index {bucket_index} was not assigned to rank {rank}"
assert bucket_index in overlap_info.offsets, (
f"Bucket index {bucket_index} was not assigned to rank {rank}"
)
gradients_offset = overlap_info.offsets[bucket_index]
bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
bucket_offset = bucket_assignment.offset
@ -77,13 +77,13 @@ def _broadcast_bucket(
:class:`ZeroRedundancyOptimizer` instance.
"""
overlap_info = zero._overlap_info
assert (
len(overlap_info.assigned_ranks_per_bucket) > bucket_index
), "`assigned_ranks_per_bucket` is not fully constructed"
assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
"`assigned_ranks_per_bucket` is not fully constructed"
)
# Sort to ensure the same ordering across ranks
assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
assert len(assigned_ranks) > 0, (
f"Bucket {bucket_index} should be " "assigned to at least one rank"
f"Bucket {bucket_index} should be assigned to at least one rank"
)
for assigned_rank in assigned_ranks:
bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
@ -273,9 +273,9 @@ def hook_with_zero_step(
rank = zero.global_rank
assert overlap_info.status == _OverlapStatus.INITIALIZED
assert (
len(overlap_info.assigned_ranks_per_bucket) > bucket_index
), "`assigned_ranks_per_bucket` is not fully constructed"
assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
"`assigned_ranks_per_bucket` is not fully constructed"
)
assigned_to_bucket = (
rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
)
@ -288,9 +288,9 @@ def hook_with_zero_step(
# Check that buckets are indexed incrementally starting from 0 in the
# order of their autograd hooks firing
if len(overlap_info.bucket_indices_seen) > 0:
assert (
overlap_info.bucket_indices_seen[-1] == bucket_index - 1
), "Bucket indices are not in incremental order"
assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, (
"Bucket indices are not in incremental order"
)
else:
assert bucket_index == 0, "Bucket indices do not start from 0"
overlap_info.bucket_indices_seen.append(bucket_index)


@ -129,7 +129,7 @@ def bf16_compress_hook(
def fp16_compress_wrapper(
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
"""
Cast input tensor to ``torch.float16``, cast result of hook back to input dtype.
@ -167,7 +167,7 @@ def fp16_compress_wrapper(
def bf16_compress_wrapper(
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
"""
Warning: This API is experimental, and it requires NCCL version later than 2.9.6.


@ -223,8 +223,7 @@ class Join:
self._rank = dist.get_rank(self._process_group)
self._device = device
def __enter__(self):
...
def __enter__(self): ...
def __exit__(
self,


@ -52,7 +52,10 @@ def average_parameters(
def get_params_to_average(
params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]]
params: Union[
Iterable[torch.nn.Parameter],
Iterable[dict[str, torch.nn.Parameter]],
],
):
"""
Return a list of parameters that need to average.


@ -550,9 +550,7 @@ def create_default_global_save_plan(
new_item = dataclasses.replace(item, index=new_index)
new_items.append(new_item)
assert (
item.tensor_data.chunk is not None
), f"""
assert item.tensor_data.chunk is not None, f"""
Cannot create MD for tensor without bounds.
FQN: {item.index.fqn}
"""


@ -414,41 +414,33 @@ class FileSystemBase(ABC):
@abstractmethod
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
...
) -> Generator[io.IOBase, None, None]: ...
@abstractmethod
def concat_path(
self, path: Union[str, os.PathLike], suffix: str
) -> Union[str, os.PathLike]:
...
) -> Union[str, os.PathLike]: ...
@abstractmethod
def rename(
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None:
...
) -> None: ...
@abstractmethod
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
...
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: ...
@abstractmethod
def mkdir(self, path: Union[str, os.PathLike]) -> None:
...
def mkdir(self, path: Union[str, os.PathLike]) -> None: ...
@classmethod
@abstractmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
...
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: ...
@abstractmethod
def exists(self, path: Union[str, os.PathLike]) -> bool:
...
def exists(self, path: Union[str, os.PathLike]) -> bool: ...
@abstractmethod
def rm_file(self, path: Union[str, os.PathLike]) -> None:
...
def rm_file(self, path: Union[str, os.PathLike]) -> None: ...
class FileSystem(FileSystemBase):
@ -512,7 +504,6 @@ class FileSystem(FileSystemBase):
class _FileSystemWriter(StorageWriter):
"""
Basic implementation of StorageWriter using file IO.
@ -800,9 +791,9 @@ class FileSystemReader(StorageReader):
)
target_tensor = planner.resolve_tensor(req).detach()
assert (
target_tensor.size() == tensor.size()
), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
assert target_tensor.size() == tensor.size(), (
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
)
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)


@ -135,12 +135,12 @@ def _get_state_dict_2d_layout(
for key, value in state_dict.items():
specs[key] = (None, value.size())
if _is_nested_tensor(value):
assert (
len(value.local_shards()) == 1
), "Cannot handle ST with multiple shards"
assert isinstance(
value, ShardedTensor
), "Can only handle nested ShardedTensor"
assert len(value.local_shards()) == 1, (
"Cannot handle ST with multiple shards"
)
assert isinstance(value, ShardedTensor), (
"Can only handle nested ShardedTensor"
)
shard = value.local_shards()[0]
specs[key] = (
shard.metadata.shard_offsets,


@ -151,7 +151,7 @@ class SavePlanner(abc.ABC):
>>> storage_meta: Optional[StorageMeta],
>>> is_coordinator: bool,
>>> ) -> None:
>>> # prefix all keys with `foo_``
>>> # prefix all keys with `foo_``
>>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted
@ -175,8 +175,8 @@ class SavePlanner(abc.ABC):
>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>> # This sample doesn't handle ShardedTensors
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>> # This sample doesn't handle ShardedTensors
>>> def create_global_plan(self, all_plans):
>>> iters = [iter(all_plans[0].items)] * len(all_plans)
>>> items_per_rank = [
@ -347,7 +347,7 @@ class LoadPlanner:
>>> self.is_coordinator = is_coordinator
>>>
>>> def load_bytes(self, read_item, value):
>>> # Remove the "foo_" prefix
>>> # Remove the "foo_" prefix
>>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)


@ -140,10 +140,12 @@ class StateDictOptions:
@dataclass
class _StateDictInfo(StateDictOptions):
fqn_param_mapping: dict[
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
Union[str, torch.Tensor],
Union[FQNS_T, torch.Tensor],
] = field(default_factory=dict)
shared_params_mapping: dict[
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
Union[str, torch.Tensor],
Union[FQNS_T, torch.Tensor],
] = field(default_factory=dict)
submodule_prefixes: set[str] = field(default_factory=set)
handle_model: bool = True
@ -1140,7 +1142,9 @@ def get_state_dict(
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
... fsdp_model, fsdp_optim
... )
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts will fail.


@ -125,7 +125,9 @@ def load(
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
... "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.load_state_dict(
>>> state_dict=model_state_dict,


@ -127,7 +127,9 @@ def save(
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
... "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.save(
>>> state_dict=state_dict,
>>> storage_writer=fs_storage_writer,
@ -206,7 +208,9 @@ def async_save(
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
... "/checkpoint/1"
... )
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>> state_dict=state_dict,
>>> storage_writer=fs_storage_writer,
@ -223,7 +227,9 @@ def async_save(
pg = process_group or _get_default_group()
assert (
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
), (
"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
)
storage_writer = cast(
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)


@ -32,7 +32,7 @@ R = TypeVar("R")
def _get_failure_dict(
results: list[Union[T, WRAPPED_EXCEPTION]]
results: list[Union[T, WRAPPED_EXCEPTION]],
) -> dict[int, WRAPPED_EXCEPTION]:
return cast(
dict[int, WRAPPED_EXCEPTION],


@ -221,8 +221,12 @@ else:
if cur_rank in mesh_nd:
res_flattened_mesh = flattened_mesh
self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined]
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh # type: ignore[possibly-undefined]
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root) # type: ignore[possibly-undefined]
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
res_flattened_mesh # type: ignore[possibly-undefined]
)
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(
flatten_dims_in_root
) # type: ignore[possibly-undefined]
return res_flattened_mesh
@ -242,9 +246,9 @@ else:
root_mesh = self.get_root_mesh(device_mesh)
child_mesh_dim_names = device_mesh.mesh_dim_names
if root_mesh and child_mesh_dim_names:
assert (
len(child_mesh_dim_names) == 1
), "The submesh can only be a 1D mesh."
assert len(child_mesh_dim_names) == 1, (
"The submesh can only be a 1D mesh."
)
child_mesh_dim_name = child_mesh_dim_names[0]
return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name)
return None
@ -763,7 +767,9 @@ else:
root_mesh, None
)
if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2] # type: ignore[index]
dim_group_infos = root_to_flatten_mapping[
mesh_dim # type: ignore[index]
]._dim_group_infos[0][:2]
return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos))
else:
mesh_dim = (
@ -905,9 +911,9 @@ else:
mesh_dim = 0
mesh_dim_group = not_none(self.get_group(mesh_dim))
assert isinstance(
mesh_dim_group, ProcessGroup
), "We expect ProcessGroup before calling `get_rank`!"
assert isinstance(mesh_dim_group, ProcessGroup), (
"We expect ProcessGroup before calling `get_rank`!"
)
return not_none(get_rank(mesh_dim_group))
def get_coordinate(self) -> Optional[list[int]]:


@ -334,12 +334,12 @@ class Backend(str): # noqa: SLOT000
# Allow UCC plugin if Pytorch is not built with native support.
# TODO: remove this exception once UCC plugin is fully deprecated.
if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()):
assert not hasattr(
Backend, name.upper()
), f"{name.upper()} c10d backend already exist"
assert (
name.upper() not in Backend._plugins
), f"{name.upper()} c10d backend creator function already exist"
assert not hasattr(Backend, name.upper()), (
f"{name.upper()} c10d backend already exist"
)
assert name.upper() not in Backend._plugins, (
f"{name.upper()} c10d backend creator function already exist"
)
setattr(Backend, name.upper(), name.lower())
Backend.backend_list.append(name.lower())
@ -1650,9 +1650,9 @@ def init_process_group(
if "torch._dynamo" in sys.modules:
torch._dynamo.trace_rules.clear_lru_cache()
assert (store is None) or (
init_method is None
), "Cannot specify both init_method and store."
assert (store is None) or (init_method is None), (
"Cannot specify both init_method and store."
)
if store is not None:
assert world_size > 0, "world_size must be positive if using store"
@ -1734,7 +1734,10 @@ def init_process_group(
)
_update_default_pg(default_pg)
_world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index]
_world.pg_group_ranks[GroupMember.WORLD] = { # type: ignore[index]
i: i
for i in range(GroupMember.WORLD.size()) # type: ignore[attr-defined]
}
_backend = _world.pg_map[not_none(GroupMember.WORLD)][0]
_default_pg_init_method = init_method
@ -1959,9 +1962,9 @@ def _new_process_group_helper(
if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL built in")
if backend_options is not None:
assert isinstance(
backend_options, ProcessGroupNCCL.Options
), "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
assert isinstance(backend_options, ProcessGroupNCCL.Options), (
"Expected backend_options argument to be of type ProcessGroupNCCL.Options"
)
if backend_options._timeout != timeout:
warnings.warn(
"backend_options._timeout was specified, "
@ -2001,9 +2004,9 @@ def _new_process_group_helper(
)
backend_type = ProcessGroup.BackendType.XCCL
else:
assert (
backend_str.upper() in Backend._plugins
), f"Unknown c10d backend type {backend_str.upper()}"
assert backend_str.upper() in Backend._plugins, (
f"Unknown c10d backend type {backend_str.upper()}"
)
backend_plugin = Backend._plugins[backend_str.upper()]
creator_fn = backend_plugin.creator_fn
@ -2630,8 +2633,10 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
>>> # xdoctest: +SKIP("no rank")
>>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
>>> recv_tensor = torch.randn(2, dtype=torch.float32)
>>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size)
>>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size)
>>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
>>> recv_op = dist.P2POp(
... dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
... )
>>> reqs = batch_isend_irecv([send_op, recv_op])
>>> for req in reqs:
>>> req.wait()
@ -2758,7 +2763,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
>>> # xdoctest: +SKIP("no rank")
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f'cuda:{rank}')
>>> device = torch.device(f"cuda:{rank}")
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor
tensor([1, 2], device='cuda:0') # Rank 0
@ -2770,7 +2775,9 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
>>> tensor = torch.tensor(
... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
@ -3380,9 +3387,9 @@ def recv_object_list(
)
rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
assert (
rank_sizes == rank_objects
), "Mismatch in return ranks for object sizes and objects."
assert rank_sizes == rank_objects, (
"Mismatch in return ranks for object sizes and objects."
)
# Deserialize objects using their stored sizes.
offset = 0
for i, obj_size in enumerate(object_sizes_tensor):
@ -3673,8 +3680,10 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
>>> # xdoctest: +SKIP("need process group init")
>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f'cuda:{rank}')
>>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)]
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_list = [
... torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
... ]
>>> tensor_list
[tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
[tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
@ -3689,11 +3698,15 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)]
>>> tensor_list = [
... torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
... ]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
[tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
>>> tensor = torch.tensor(
... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
@ -3769,7 +3782,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
>>> # xdoctest: +SKIP("need process group init")
>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks.
>>> device = torch.device(f'cuda:{rank}')
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor_in
tensor([1, 2], device='cuda:0') # Rank 0
@ -3969,8 +3982,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
)
elif gather_list:
raise ValueError(
"Argument ``gather_list`` must NOT be specified "
"on non-destination ranks."
"Argument ``gather_list`` must NOT be specified on non-destination ranks."
)
@ -4141,8 +4153,7 @@ def scatter(
else:
if scatter_list:
raise ValueError(
"Argument ``scatter_list`` must NOT be specified "
"on non-source ranks."
"Argument ``scatter_list`` must NOT be specified on non-source ranks."
)
input_tensors = []
output_tensors = [tensor]
@ -4225,7 +4236,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
>>> # xdoctest: +SKIP("need process group init")
>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks.
>>> device = torch.device(f'cuda:{rank}')
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
>>> # Input in concatenation form
>>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
@ -4381,7 +4392,7 @@ def all_to_all_single(
>>> # Essentially, it is similar to following operation:
>>> scatter_list = list(input.chunk(world_size))
>>> gather_list = list(output.chunk(world_size))
>>> gather_list = list(output.chunk(world_size))
>>> for i in range(world_size):
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
@ -4411,7 +4422,9 @@ def all_to_all_single(
>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
>>> input = torch.tensor(
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input
tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0
tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1
@ -4510,7 +4523,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
>>> # Essentially, it is similar to following operation:
>>> scatter_list = input
>>> gather_list = output
>>> gather_list = output
>>> for i in range(world_size):
>>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
@ -4544,7 +4557,9 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3
>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
>>> input = torch.tensor(
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input = list(input.chunk(4))
>>> input
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0
@ -4882,9 +4897,9 @@ def split_group(
backend_config = BackendConfig(backend)
if pg_options is not None:
assert isinstance(
pg_options, ProcessGroupNCCL.Options
), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
assert isinstance(pg_options, ProcessGroupNCCL.Options), (
"Expected pg_options argument to be of type ProcessGroupNCCL.Options"
)
else:
# default pg_options same as the parent process group
pg_options = parent_backend.options
@ -5086,9 +5101,9 @@ def _new_group_with_tag(
if device_id is None:
device_id = default_pg.bound_device_id
elif default_pg.bound_device_id is not None:
assert (
device_id == default_pg.bound_device_id
), "Mismatched bound device between new pg and the default pg."
assert device_id == default_pg.bound_device_id, (
"Mismatched bound device between new pg and the default pg."
)
default_backend, default_store = _world.pg_map[default_pg]
global_rank = default_pg.rank()
global_world_size = default_pg.size()
@ -5408,9 +5423,9 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro
def _find_or_create_pg_by_ranks_and_tag(
tag: str, ranks: list[int], stride: int
) -> ProcessGroup:
assert (
len(ranks) % stride == 0
), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
assert len(ranks) % stride == 0, (
f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
)
my_rank = get_rank()
my_ranks = None


@ -40,8 +40,9 @@ def worker_main() -> Generator[None, None, None]:
def main():
pass
if __name__=="__main__":
main()
if __name__ == "__main__":
main()
"""
with ExitStack() as stack:


@ -14,7 +14,10 @@ Example of usage:
::
from torch.distributed.elastic import events
event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
event = events.Event(
name="test_event", source=events.EventSource.WORKER, metadata={...}
)
events.get_logging_handler(destination="console").info(event)
"""


@ -52,11 +52,12 @@ The example below measures the latency for the ``calculate()`` function.
metrics.configure(metrics.NullMetricsHandler())
metrics.configure(metrics.ConsoleMetricsHandler(), "my_module")
def my_method():
start = time.time()
calculate()
end = time.time()
metrics.put_metric("calculate_latency", int(end-start), "my_module")
start = time.time()
calculate()
end = time.time()
metrics.put_metric("calculate_latency", int(end - start), "my_module")
You may also use the torch.distributed.elastic.metrics.prof` decorator
to conveniently and succinctly profile functions
@ -70,15 +71,16 @@ to conveniently and succinctly profile functions
metrics.configure(metrics.ConsoleMetricsHandler(), "foobar")
metrics.configure(metrics.ConsoleMetricsHandler(), "Bar")
@metrics.prof
def foo():
pass
pass
class Bar():
@metrics.prof
def baz():
pass
class Bar:
@metrics.prof
def baz():
pass
``@metrics.prof`` will publish the following metrics
::
@ -102,8 +104,8 @@ console.
import torch.distributed.elastic.metrics as metrics
metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic")
metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app")
metrics.configure(metrics.ConsoleMetricHandler(), group="torchelastic")
metrics.configure(metrics.ConsoleMetricHandler(), group="my_app")
**Writing a Custom Metric Handler**:
@ -117,13 +119,15 @@ Below is a toy example that prints the metrics to ``stdout``
import torch.distributed.elastic.metrics as metrics
class StdoutMetricHandler(metrics.MetricHandler):
def emit(self, metric_data):
ts = metric_data.timestamp
group = metric_data.group_name
name = metric_data.name
value = metric_data.value
print(f"[{ts}][{group}]: {name}={value}")
def emit(self, metric_data):
ts = metric_data.timestamp
group = metric_data.group_name
name = metric_data.name
value = metric_data.value
print(f"[{ts}][{group}]: {name}={value}")
metrics.configure(StdoutMetricHandler(), group="my_app")


@ -123,6 +123,7 @@ def prof(fn=None, group: str = "torchelastic"):
def x():
pass
@metrics.prof(group="agent")
def y():
pass


@ -20,22 +20,23 @@ Usage 1: Launching two trainers as a function
from torch.distributed.elastic.multiprocessing import Std, start_processes
def trainer(a, b, c):
pass # train
pass # train
# runs two trainers
# LOCAL_RANK=0 trainer(1,2,3)
# LOCAL_RANK=1 trainer(4,5,6)
ctx = start_processes(
name="trainer",
entrypoint=trainer,
args={0: (1,2,3), 1: (4,5,6)},
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
log_dir="/tmp/foobar",
redirects=Std.ALL, # write all worker stdout/stderr to a log file
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
)
name="trainer",
entrypoint=trainer,
args={0: (1, 2, 3), 1: (4, 5, 6)},
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
log_dir="/tmp/foobar",
redirects=Std.ALL, # write all worker stdout/stderr to a log file
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
)
# waits for all copies of trainer to finish
ctx.wait()


@ -165,9 +165,11 @@ def to_map(
Example:
::
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
to_map(
{0: Std.OUT, 1: Std.OUT}, local_world_size=2
) # returns: {0: Std.OUT, 1: Std.OUT}
"""
if isinstance(val_or_map, Std):
return dict.fromkeys(range(local_world_size), val_or_map)
@ -304,7 +306,9 @@ class DefaultLogsSpecs(LogsSpecs):
if not self._run_log_dir:
self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload]
attempt_log_dir = os.path.join(
self._run_log_dir, f"attempt_{restart_count}"
) # type: ignore[call-overload]
shutil.rmtree(attempt_log_dir, ignore_errors=True)
os.makedirs(attempt_log_dir)
@ -868,9 +872,7 @@ class SubprocessContext(PContext):
if result.is_failed():
first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
logger.error(
"failed (exitcode: %s)"
" local_rank: %s (pid: %s)"
" of binary: %s",
"failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s",
first_failure.exitcode,
first_failure.local_rank,
first_failure.pid,


@ -318,14 +318,14 @@ def record(
error_handler = get_error_handler()
error_handler.initialize()
try:
foobar()
foobar()
except ChildFailedError as e:
_, failure = e.get_first_failure()
error_handler.dump_error_file(failure.error_file, failure.exitcode)
raise
_, failure = e.get_first_failure()
error_handler.dump_error_file(failure.error_file, failure.exitcode)
raise
except Exception as e:
error_handler.record_exception(e)
raise
error_handler.record_exception(e)
raise
.. important:: use this decorator once per process at the top level method,
typically this is the main method.
@ -338,8 +338,9 @@ def record(
def main():
pass
if __name__=="__main__":
main()
if __name__ == "__main__":
main()
"""
if not error_handler:


@ -120,11 +120,7 @@ of the following implementations that come with PyTorch:
backend = C10dRendezvousBackend(store, "my_run_id")
rdzv_handler = DynamicRendezvousHandler.from_backend(
run_id="my_run_id",
store=store,
backend=backend,
min_nodes=2,
max_nodes=4
run_id="my_run_id", store=store, backend=backend, min_nodes=2, max_nodes=4
)
"""


@ -89,8 +89,14 @@ class RendezvousStoreInfo:
addr = local_addr or socket.getfqdn()
# When TCPStore is not shared, we fallback to get_free_port.
port = server_port or get_free_port()
store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type]
store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type]
store.set(
RendezvousStoreInfo.MASTER_ADDR_KEY,
addr.encode(encoding="UTF-8"), # type: ignore[arg-type]
)
store.set(
RendezvousStoreInfo.MASTER_PORT_KEY,
str(port).encode(encoding="UTF-8"), # type: ignore[arg-type]
)
addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
port = int(


@ -413,9 +413,9 @@ class EtcdRendezvous:
active_version = self.wait_for_peers(expected_version)
state = json.loads(active_version.value)
assert (
state["version"] == expected_version
), "Logic error: failed to observe version mismatch"
assert state["version"] == expected_version, (
"Logic error: failed to observe version mismatch"
)
return self.confirm_phase(expected_version, this_rank)
@ -533,9 +533,9 @@ class EtcdRendezvous:
"Rendezvous version changed. Must try join the new one."
)
assert (
len(state["participants"]) < self._num_max_workers
), "Logic error: joinable rendezvous should always have space left"
assert len(state["participants"]) < self._num_max_workers, (
"Logic error: joinable rendezvous should always have space left"
)
this_rank = len(state["participants"])
state["participants"].append(this_rank)


@ -86,11 +86,15 @@ def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
def create_my_rdzv(params: RendezvousParameters):
return MyCustomRdzv(params)
return MyCustomRdzv(params)
rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters)
my_rdzv_handler = get_rendezvous_handler(
"my_rdzv_backend_name", RendezvousParameters
)
"""
return handler_registry.create_handler(params)


@ -57,10 +57,10 @@ def get_all(store, rank: int, prefix: str, world_size: int):
::
values = get_all(store, 'torchelastic/data', 3)
value1 = values[0] # retrieves the data for key torchelastic/data0
value2 = values[1] # retrieves the data for key torchelastic/data1
value3 = values[2] # retrieves the data for key torchelastic/data2
values = get_all(store, "torchelastic/data", 3)
value1 = values[0] # retrieves the data for key torchelastic/data0
value2 = values[1] # retrieves the data for key torchelastic/data1
value3 = values[2] # retrieves the data for key torchelastic/data2
"""
data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])


@ -2,6 +2,7 @@
"""
This file includes private common utilities for FSDP.
"""
import logging
import traceback
import warnings
@ -200,9 +201,9 @@ def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamH
# handles, meaning no entry in `_fully_sharded_module_to_handles`
if state._handle is None:
return None
assert (
module in state._fully_sharded_module_to_handle
), f"Expects a fully sharded module but got {module} on rank {state.rank}"
assert module in state._fully_sharded_module_to_handle, (
f"Expects a fully sharded module but got {module} on rank {state.rank}"
)
return state._fully_sharded_module_to_handle[module]
else:
# NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
@ -255,9 +256,9 @@ def _named_parameters_with_duplicates(
This API is required as some modules overwrite `named_parameters()` but do not support
`remove_duplicate`.
"""
assert (
"remove_duplicate" not in kwargs
), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
assert "remove_duplicate" not in kwargs, (
"_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
)
kwargs["remove_duplicate"] = False
try:
ret = list(module.named_parameters(**kwargs))


@ -190,9 +190,9 @@ class _ExecOrderData:
return
if self.is_first_iter:
msg_prefix = "Forward order differs across ranks:"
optional_local_indices: tuple[
Optional[int], ...
] = self._get_handle_indices(handle)
optional_local_indices: tuple[Optional[int], ...] = (
self._get_handle_indices(handle)
)
device = handle.device # guaranteed to be non-CPU
num_valid_indices = sum(
(index is not None) for index in optional_local_indices
@ -250,8 +250,7 @@ class _ExecOrderData:
(
rank,
world_indices[
rank
* num_valid_indices : (rank + 1)
rank * num_valid_indices : (rank + 1)
* num_valid_indices
],
)


@ -586,7 +586,10 @@ class FlatParamHandle:
)
self._fsdp_extension = fsdp_extension
self._init_flat_param_and_metadata(
params, fully_sharded_module, self._aligned_numel, use_orig_params # type: ignore[arg-type]
params,
fully_sharded_module,
self._aligned_numel,
use_orig_params, # type: ignore[arg-type]
)
self._use_unsharded_views(as_params=False)
@ -978,9 +981,9 @@ class FlatParamHandle:
shard_param_infos = self._get_shard_metadata(
unsharded_start_idx, unsharded_end_idx
)
assert (
len(shard_param_infos) == flat_param._num_params
), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
assert len(shard_param_infos) == flat_param._num_params, (
f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
)
flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
@ -996,9 +999,9 @@ class FlatParamHandle:
unsharded flat parameter specifying the shard.
"""
flat_param_offsets = self._get_flat_param_offsets()
assert len(flat_param_offsets) == len(
self.flat_param._numels_with_padding
), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), (
f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
)
shard_param_infos: list[_ShardParamInfo] = []
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
@ -1075,9 +1078,9 @@ class FlatParamHandle:
else:
chunk = chunks[rank]
numel_to_pad = chunks[0].numel() - chunk.numel()
assert (
numel_to_pad >= 0
), "Chunk's size should be at most the first chunk's size"
assert numel_to_pad >= 0, (
"Chunk's size should be at most the first chunk's size"
)
return chunk, numel_to_pad
@staticmethod
@ -1302,7 +1305,8 @@ class FlatParamHandle:
self._check_low_precision_shard()
flat_param = self.flat_param
_alloc_storage(
flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined]
flat_param._mp_shard,
flat_param._local_shard.size(), # type: ignore[attr-defined]
)
# `copy_()` implicitly casts to the low precision
flat_param._mp_shard.copy_( # type: ignore[attr-defined]
@ -1498,7 +1502,8 @@ class FlatParamHandle:
# default stream suffices since the default stream waits for the
# unshard stream.
_no_dispatch_record_stream(
self.flat_param._mp_shard, self._device_handle.current_stream() # type: ignore[attr-defined]
self.flat_param._mp_shard,
self._device_handle.current_stream(), # type: ignore[attr-defined]
)
_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
@ -1593,8 +1598,7 @@ class FlatParamHandle:
f"but got {flat_param.grad.device}",
)
prev_iter_synced_gradients = (
flat_param.grad.size()
== flat_param._local_shard.size() # type: ignore[attr-defined]
flat_param.grad.size() == flat_param._local_shard.size() # type: ignore[attr-defined]
)
if prev_iter_synced_gradients:
# TODO (awgu): Gradient accumulation outside `no_sync()`
@ -1668,8 +1672,7 @@ class FlatParamHandle:
cast_grad_to_param_dtype_if_needed(flat_param)
else:
_p_assert(
not self.uses_sharded_strategy
or not flat_param._post_backward_called, # type: ignore[attr-defined]
not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined]
"All sharded parameters that received a gradient in the "
"post-backward should use `_saved_grad_shard`",
)
@ -2504,7 +2507,8 @@ class FlatParamHandle:
"""Return the FQNs of the parameters present in this rank's shard."""
fqns_in_shard: list[str] = []
for fqn, shard_param_info in zip(
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
self.flat_param._fqns,
self.flat_param._shard_param_infos, # type: ignore[attr-defined]
):
if shard_param_info.in_shard:
fqns_in_shard.append(fqn)
@ -2694,7 +2698,7 @@ def _safe_setattr_tensor_or_param(
def _convert_to_params(
tensors: list[Union[torch.Tensor, nn.Parameter]]
tensors: list[Union[torch.Tensor, nn.Parameter]],
) -> list[nn.Parameter]:
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]


@ -374,9 +374,9 @@ def foreach_reduce(
for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)):
if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
continue
assert (
unsharded_grad.size(shard_dim) % world_size == 0
), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
assert unsharded_grad.size(shard_dim) % world_size == 0, (
f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
)
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
unsharded_grads[i] = torch.cat(chunks, dim=0)
padded_unsharded_sizes = tuple(


@ -26,9 +26,9 @@ if torch._running_with_deploy():
else:
def detect_compiled_autograd():
assert (
not torch.compiler.is_compiling()
), "`detect_compiled_autograd()` is designed to be called in eager mode"
assert not torch.compiler.is_compiling(), (
"`detect_compiled_autograd()` is designed to be called in eager mode"
)
global _compiled_autograd_enabled
import torch._dynamo.compiled_autograd as ca


@ -304,9 +304,9 @@ class FSDPParam:
f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
)
split_factor = self._tp_spec.num_shards_map[shard_dim]
assert (
2 <= self._spmd_mesh.ndim <= 3
), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
assert 2 <= self._spmd_mesh.ndim <= 3, (
f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
)
self._spmd_placements: tuple[Placement, ...]
dp_shard_tp_placement = (
(
@ -520,8 +520,9 @@ class FSDPParam:
unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
if hasattr(self, "_unsharded_param"):
assert compiled_autograd_enabled()
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
self._unsharded_param
with (
torch.no_grad(),
torch.autograd._unsafe_preserve_version_counter(self._unsharded_param),
):
# NOTE: Under compile, if an unsharded param goes through
# resize_(full) -> copy_ -> resize_(0) pattern, we will remove those
@ -785,9 +786,9 @@ class FSDPParam:
assert isinstance(grad, DTensor), f"{type(grad)}"
placements = self._tp_spec.placements
if placements != grad.placements:
assert len(self._tp_spec.placements) == len(
grad.placements
), f"{self._tp_spec=} {grad.placements=}"
assert len(self._tp_spec.placements) == len(grad.placements), (
f"{self._tp_spec=} {grad.placements=}"
)
grad = grad.redistribute(placements=placements)
grad = grad._local_tensor
return grad
@ -846,9 +847,9 @@ class FSDPParam:
shard_dim = self.fsdp_placement.dim
length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
if local_tensor.size() != padded_sharded_size:
assert (
shard_dim == 0
), f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
assert shard_dim == 0, (
f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
)
padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(
local_tensor


@ -424,9 +424,9 @@ class FSDPParamGroup:
if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
# this means the native HSDP is not enabled,
# but user may want to have a custom HSDP setup
assert (
self._all_reduce_hook is not None
), "all reduce hook stream is specified but hook itself is missing."
assert self._all_reduce_hook is not None, (
"all reduce hook stream is specified but hook itself is missing."
)
all_reduce_stream = self._all_reduce_hook_stream
else:
all_reduce_stream = self.comm_ctx.all_reduce_stream
@ -513,9 +513,10 @@ class FSDPParamGroup:
else:
raise ValueError(f"Unknown pass type: {pass_type}")
target_fqn = target_fsdp_param_group._module_fqn
with record_function(
f"FSDP::{pass_type}_prefetch for {target_fqn}"
), target_fsdp_param_group.use_training_state(training_state):
with (
record_function(f"FSDP::{pass_type}_prefetch for {target_fqn}"),
target_fsdp_param_group.use_training_state(training_state),
):
async_op = target_fsdp_param_group.unshard_async_op
target_fsdp_param_group.unshard(async_op)
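
The prefetch hunk above (and the similar one in the FSDPParam file earlier) shows the other recurring rewrite: several context managers are now grouped inside a single parenthesized `with (...)` block, one per line with a trailing comma, instead of being wrapped by splitting a call argument. A hedged sketch using only standard-library context managers; the managers and the suppressed exception are illustrative:

import contextlib

# Old style: the second manager is wrapped by breaking its argument list.
with contextlib.nullcontext(), contextlib.suppress(
    ValueError
):
    pass

# New style: the managers themselves are parenthesized (Python 3.10+ syntax),
# one per line, with a trailing comma.
with (
    contextlib.nullcontext(),
    contextlib.suppress(ValueError),
):
    pass
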
@ -592,9 +593,9 @@ class FSDPParamGroup:
def _register_state_dict_hooks(self) -> None:
num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
assert (
num_pre_save_hooks == num_pre_load_hooks
), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
assert num_pre_save_hooks == num_pre_load_hooks, (
f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
)
if num_pre_save_hooks > 0:
return # already registered
modules_with_fsdp_params: set[nn.Module] = {
@ -605,12 +606,12 @@ class FSDPParamGroup:
self._to_sharded()
for module in modules_with_fsdp_params:
self._module_to_pre_save_state_dict_hook_handle[
module
] = module.register_state_dict_pre_hook(to_sharded_hook)
self._module_to_pre_load_state_dict_hook_handle[
module
] = module._register_load_state_dict_pre_hook(to_sharded_hook)
self._module_to_pre_save_state_dict_hook_handle[module] = (
module.register_state_dict_pre_hook(to_sharded_hook)
)
self._module_to_pre_load_state_dict_hook_handle[module] = (
module._register_load_state_dict_pre_hook(to_sharded_hook)
)
# Properties #
@property
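
The state-dict hook hunk above (and later ones in the ZeRO and CommDebugMode files) shows how assignments to a long subscript are rewrapped: the key stays with the target and the right-hand side is parenthesized, instead of the key being pushed onto its own line. A sketch with invented names and values:

handles: dict[str, str] = {}
module_name = "block0"

# Old wrapping: the subscript key is what gets split across lines.
handles[
    module_name
] = "pre-save hook handle"

# New wrapping: the key stays inline and the value is parenthesized.
handles[module_name] = (
    "pre-save hook handle"
)

assert handles == {"block0": "pre-save hook handle"}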

View File

@ -60,8 +60,7 @@ def fully_shard(
mp_policy: MixedPrecisionPolicy = ...,
offload_policy: OffloadPolicy = ...,
ignored_params: Optional[set[nn.Parameter]] = ...,
) -> FSDPModule:
...
) -> FSDPModule: ...
@overload
@ -74,8 +73,7 @@ def fully_shard(
mp_policy: MixedPrecisionPolicy = ...,
offload_policy: OffloadPolicy = ...,
ignored_params: Optional[set[nn.Parameter]] = ...,
) -> list[FSDPModule]:
...
) -> list[FSDPModule]: ...
# The decorator adds a state object to `module` that can be accessed via
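
The `fully_shard` hunk above (and the later ShardedGradScaler and _NamedOptimizer overloads) shows how `@overload` stubs are collapsed so the `...` body sits on the signature line. A self-contained sketch with an invented function; only the placement of `...` changes, the type-checking behaviour does not:

from typing import Union, overload

@overload
def as_bytes(value: str) -> bytes: ...
@overload
def as_bytes(value: bytes) -> bytes: ...
def as_bytes(value: Union[str, bytes]) -> bytes:
    # Real implementation; the stubs above exist only for type checkers.
    return value.encode() if isinstance(value, str) else value

assert as_bytes("hi") == b"hi"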

View File

@ -243,9 +243,9 @@ def _init_inter_node_process_group(
if local_rank == my_local_rank:
inter_node_pg = grp
assert (
inter_node_pg is not None
), f"{my_local_rank} expected to assign inter-node pg, but did not"
assert inter_node_pg is not None, (
f"{my_local_rank} expected to assign inter-node pg, but did not"
)
return inter_node_pg

View File

@ -145,9 +145,9 @@ def _unflatten_optim_state(
dict will need to map these entries using the proper unflattened
parameter IDs.
"""
assert (
not shard_state or to_save
), "If ``shard_state`` is True, ``to_save`` has to be True."
assert not shard_state or to_save, (
"If ``shard_state`` is True, ``to_save`` has to be True."
)
consolidated_state = _communicate_optim_state(
fsdp_param_info,
flat_param_state,
@ -218,9 +218,9 @@ def _communicate_optim_state(
):
tensor_state[state_name] = value
continue
assert (
fsdp_state.compute_device is not None
), "compute_device has not been initialized"
assert fsdp_state.compute_device is not None, (
"compute_device has not been initialized"
)
if value.device.type != fsdp_state.compute_device.type:
value = value.to(fsdp_state.compute_device)
# Assume that positive-dimension tensor optimizer state
@ -394,7 +394,10 @@ def _shard_orig_param_state(
and value.dim() > 0
and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
):
value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator]
value = value.flatten()[
intra_param_start_idx : intra_param_end_idx # type: ignore[operator]
+ 1
].clone()
new_optim_state[state_name] = value
return new_optim_state
@ -489,9 +492,9 @@ def _flatten_optim_state_dict(
if flat_state:
flat_osd_state[key] = flat_state
elif use_orig_params:
assert (
len(fqns) == 1
), f"use_orig_params is True but there are multiple FQNs, {fqns}."
assert len(fqns) == 1, (
f"use_orig_params is True but there are multiple FQNs, {fqns}."
)
if optim is not None: # NamedOptimizer or KeyedOptimizer case.
state = optim.state.get(param, None) # type: ignore[call-overload]
if state is not None:
@ -570,14 +573,13 @@ def _flatten_optim_state(
flat_param = handle.flat_param
num_unflat_params = len(unflat_param_names)
assert num_unflat_params > 0, (
"Expects at least one unflattened parameter corresponding to the "
"flat parameter"
"Expects at least one unflattened parameter corresponding to the flat parameter"
)
unflat_param_shapes = flat_param._shapes
num_unflat_param_shapes = len(unflat_param_shapes)
assert (
num_unflat_params == num_unflat_param_shapes
), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
assert num_unflat_params == num_unflat_param_shapes, (
f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
)
# Check if these unflattened parameters have any optimizer state
has_state = [
@ -759,8 +761,7 @@ def _flatten_tensor_optim_state(
flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined]
assert flat_tensor.shape == flat_param_shape, (
f"tensor optim state: {flat_tensor.shape} "
f"flat parameter: {flat_param_shape}"
f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}"
)
return flat_tensor
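
The two optimizer-state hunks above also show adjacent string literals being merged: once the message is rewrapped and fits on one line, the implicit concatenation of two literals collapses into a single literal. A sketch with an invented check and message:

flat_count, shape_count = 3, 3

# Old: the message is built from two implicitly concatenated literals.
assert flat_count == shape_count, (
    "Expects the flattened parameter count to match the "
    "number of recorded shapes"
)

# New: the pieces fit on one line and are merged into a single literal.
assert flat_count == shape_count, (
    "Expects the flattened parameter count to match the number of recorded shapes"
)
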
@ -1065,9 +1066,9 @@ def _get_param_key_to_param(
"""
clean_fqn_to_curr_fqn: dict[str, str] = {}
if is_named_optimizer:
assert (
param_to_fqns is not None and flat_param_to_fqn is not None
), "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
assert param_to_fqns is not None and flat_param_to_fqn is not None, (
"The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
)
assert model is not None
for key, _ in _named_parameters_with_duplicates(model):
clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
@ -1150,9 +1151,9 @@ def _check_missing_keys_on_rank(
continue
param_key = optim_state_key_to_param_key[r0_optim_state_key]
if isinstance(param_key, int):
assert param_key >= 0 and param_key < len(
param_key_to_param
), "Check the `param_key_to_param` construction"
assert param_key >= 0 and param_key < len(param_key_to_param), (
"Check the `param_key_to_param` construction"
)
# We cannot use FSDPState.compute_device as this API is a global view.
device = _get_pg_default_device(group)
num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)

View File

@ -121,9 +121,9 @@ def _all_gather_dtensor(
"""
All gather a DTensor in its sharded dimension and return the local tensor.
"""
assert (
root_mesh == tensor.device_mesh
), "The device mesh of a tensor should be a root mesh."
assert root_mesh == tensor.device_mesh, (
"The device mesh of a tensor should be a root mesh."
)
placements = list(copy.deepcopy(tensor.placements))
# FSDP placements: [Shard(0)] -> [Replicate()]

View File

@ -466,9 +466,9 @@ def _local_pre_load_state_dict_hook(
)
return
load_tensor = state_dict[fqn]
assert isinstance(
load_tensor, ShardedTensor
), "Tensors in local_state_dict should be ShardedTensor."
assert isinstance(load_tensor, ShardedTensor), (
"Tensors in local_state_dict should be ShardedTensor."
)
# Convert the ShardedTensor to a Tensor.
flat_param = _module_handle(fsdp_state, module).flat_param

View File

@ -143,9 +143,9 @@ class _ExecOrderTracer:
named_params = list(module.named_parameters())
curr_module = exec_info.curr_module
if named_params:
assert (
curr_module in exec_info.module_to_param_usage_infos
), "The current module should have already been processed by a patched `call_module`"
assert curr_module in exec_info.module_to_param_usage_infos, (
"The current module should have already been processed by a patched `call_module`"
)
exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
_ParamUsageInfo(module, named_params)
)

View File

@ -185,9 +185,9 @@ def _unshard_fsdp_state_params(
yield
return
assert (
handle._training_state == HandleTrainingState.IDLE
), f"Expects the handle training to be IDLE but got {handle._training_state}"
assert handle._training_state == HandleTrainingState.IDLE, (
f"Expects the handle training to be IDLE but got {handle._training_state}"
)
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

View File

@ -306,16 +306,21 @@ class FullStateDictConfig(StateDictConfig):
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>> state = fsdp.state_dict()
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
>>> state_dict = torch.load("my_checkpoint.pt")
>>> model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> fsdp = FSDP(
... model,
... device_id=torch.cuda.current_device(),
... auto_wrap_policy=...,
... sync_module_states=True,
... )
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
Attributes:
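
Doctest code inside docstrings, as in the `FullStateDictConfig` example above, is reformatted like ordinary code: long calls are broken across `...` continuation lines and comment spacing is normalized. A tiny invented doctest laid out the same way:

def scaled(x: float) -> float:
    """Scale ``x`` by two.

    Example::

        >>> value = scaled(3.0)  # two spaces before an inline comment
        >>> result = max(
        ...     value,
        ...     0.0,
        ... )
        >>> result
        6.0
    """
    return 2 * x

if __name__ == "__main__":
    import doctest

    doctest.testmod()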

View File

@ -723,9 +723,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
if prev_state_dict_type is None:
prev_state_dict_type = submodule._state_dict_type
else:
assert (
prev_state_dict_type == submodule._state_dict_type
), "All FSDP modules should have the same state_dict_type."
assert prev_state_dict_type == submodule._state_dict_type, (
"All FSDP modules should have the same state_dict_type."
)
if prev_state_dict_config is None:
prev_state_dict_config = submodule._state_dict_config
else:
@ -738,7 +738,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
assert isinstance(
submodule._optim_state_dict_config,
type(prev_optim_state_dict_config),
), "All FSDP modules must have the same type of optim_state_dict_config."
), (
"All FSDP modules must have the same type of optim_state_dict_config."
)
submodule._state_dict_type = state_dict_type
submodule._state_dict_config = state_dict_config
@ -2153,9 +2155,9 @@ def _get_param_to_fqn(
"""
param_to_param_names = _get_param_to_fqns(model)
for param_names in param_to_param_names.values():
assert (
len(param_names) > 0
), "`_get_param_to_fqns()` should not construct empty lists"
assert len(param_names) > 0, (
"`_get_param_to_fqns()` should not construct empty lists"
)
if len(param_names) > 1:
raise RuntimeError(
"Each parameter should only map to one parameter name but got "

View File

@ -112,20 +112,16 @@ class ShardedGradScaler(GradScaler):
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
@overload
def scale(self, outputs: torch.Tensor) -> torch.Tensor:
...
def scale(self, outputs: torch.Tensor) -> torch.Tensor: ...
@overload
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]:
...
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ...
@overload
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
...
def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ...
@overload
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
...
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ...
def scale(
self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
@ -323,8 +319,10 @@ class ShardedGradScaler(GradScaler):
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
reason = (
"new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
torch.FloatTensor with requires_grad=False."
)
assert new_scale.device.type == self._device, reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason

View File

@ -61,9 +61,9 @@ def _post_order_apply(
"Non-root modules should have their module name set but got "
f"an empty module name for {module}"
)
assert isinstance(
optional_module, nn.Module
), f"fn should return None or an nn.Module but got {optional_module}"
assert isinstance(optional_module, nn.Module), (
f"fn should return None or an nn.Module but got {optional_module}"
)
setattr(parent_module, module_name, optional_module)
_post_order_apply_inner(root_module, "", None)
@ -575,9 +575,9 @@ class _ConfigAutoWrap:
)
_ConfigAutoWrap.in_autowrap_context = True
# Get and save the wrapper cls for the context.
assert (
"wrapper_cls" in kwargs.keys()
), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
assert "wrapper_cls" in kwargs.keys(), (
"Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
)
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
del kwargs["wrapper_cls"]
# Save the rest.

View File

@ -183,8 +183,7 @@ def parse_args(args):
def launch(args):
if args.no_python and not args.use_env:
raise ValueError(
"When using the '--no-python' flag,"
" you must also set the '--use-env' flag."
"When using the '--no-python' flag, you must also set the '--use-env' flag."
)
run(args)

View File

@ -39,7 +39,10 @@ _REMOTE_MODULE_PICKLED_ATTRIBUTES = (
"module_rref",
)
_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc]
_SerializedRemoteModule = collections.namedtuple( # type: ignore[misc]
"_SerializedRemoteModule",
_REMOTE_MODULE_PICKLED_ATTRIBUTES,
)
# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES

View File

@ -26,15 +26,15 @@ sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)
def get_arg_return_types_from_interface(module_interface):
assert getattr(
module_interface, "__torch_script_interface__", False
), "Expect a TorchScript class interface decorated by @torch.jit.interface."
assert getattr(module_interface, "__torch_script_interface__", False), (
"Expect a TorchScript class interface decorated by @torch.jit.interface."
)
qualified_name = torch._jit_internal._qualified_name(module_interface)
cu = torch.jit._state._python_cu
module_interface_c = cu.get_interface(qualified_name)
assert (
"forward" in module_interface_c.getMethodNames()
), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
assert "forward" in module_interface_c.getMethodNames(), (
f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
)
method_schema = module_interface_c.getMethod("forward")
arg_str_list = []

View File

@ -5,6 +5,7 @@ optimizer locally on the workers where the parameters live. The distributed
optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
apply the gradients on each worker.
"""
import warnings
import torch

View File

@ -44,10 +44,10 @@ def _apply_optimizer_in_backward(
param_1 = next(params_generator)
remainder_params = list(params_generator)
apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})
apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02})
apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04})
model(...).sum().backward() # after backward, parameters will already
model(...).sum().backward() # after backward, parameters will already
# have their registered optimizer(s) applied.
"""
@ -111,7 +111,7 @@ def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Opt
List[torch.optim.Optimizer]: the in-backward optimizers.
Example::
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01})
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
optims = _get_optimizers_in_backward(model)
"""
optims: list[torch.optim.Optimizer] = []
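
Example code embedded in these optimizer docstrings is normalized like real code as well: floats gain an explicit leading zero (`.02` becomes `0.02`) and single-quoted strings become double-quoted. A trivial runnable sketch with invented values:

# Old spelling from the docstring example: bare leading-dot float, single quotes.
old_defaults = {'lr': .02}

# New spelling: explicit leading zero and double quotes.
new_defaults = {"lr": 0.02}

assert old_defaults == new_defaults  # only the spelling differs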

View File

@ -147,12 +147,10 @@ class _NamedOptimizer(optim.Optimizer):
return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
@overload
def step(self, closure: None = ...) -> None:
...
def step(self, closure: None = ...) -> None: ...
@overload
def step(self, closure: Callable[[], float]) -> float:
...
def step(self, closure: Callable[[], float]) -> float: ...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""

View File

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
r"""Zero Redundancy Optimizer."""
import collections
import copy
import enum
@ -262,9 +263,9 @@ class _OverlapInfo:
meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
in preparation for the next iteration.
"""
assert (
len(self.broadcast_handles) == self.num_bucket_assignments
), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
assert len(self.broadcast_handles) == self.num_bucket_assignments, (
f"Missing at least one broadcast handle on rank {dist.get_rank()}"
)
_ = [x.wait() for x in self.broadcast_handles]
self.broadcast_handles.clear()
@ -909,9 +910,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
params_per_rank = overlap_info.params_per_rank
offsets = overlap_info.offsets
self._bucket_assignments_per_rank_cache[assigned_rank][
bucket_index
] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = (
_DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
)
if self.global_rank == assigned_rank:
offsets[bucket_index] = len(params_per_rank[assigned_rank])
params_per_rank[assigned_rank].extend(bucket_params)
@ -927,9 +928,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
mapping bucket indices to :class:`_DDPBucketAssignment` s for each
rank.
"""
assert (
self._overlap_with_ddp
), "`_bucket_assignments_per_rank` should only be used if `overlap_with_ddp=True`"
assert self._overlap_with_ddp, (
"`_bucket_assignments_per_rank` should only be used if `overlap_with_ddp=True`"
)
if len(self._bucket_assignments_per_rank_cache) > 0:
return self._bucket_assignments_per_rank_cache
@ -1076,9 +1077,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
"Specifying `gradients` should not "
"be used when `overlap_with_ddp=False`"
)
assert (
closure is None
), "`closure` is not supported when using a local functional optimizer"
assert closure is None, (
"`closure` is not supported when using a local functional optimizer"
)
loss = self.optim.step(gradients=gradients)
# Sync any updated attributes in the local optimizer to the exposed
@ -1221,9 +1222,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
for rank, local_state_dict in enumerate(self._all_state_dicts):
local_param_groups = local_state_dict["param_groups"]
global_param_groups = self._partition_parameters()[rank]
assert len(local_param_groups) == len(
global_param_groups
), "Mismatch between number of local and global parameter groups"
assert len(local_param_groups) == len(global_param_groups), (
"Mismatch between number of local and global parameter groups"
)
for local_param_group, global_param_group in zip(
local_param_groups, global_param_groups
@ -1233,9 +1234,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
local_param_indices = local_param_group["params"]
global_params = global_param_group["params"]
assert len(local_param_indices) == len(
global_params
), "Mismatch between number of local and global parameters in parameter group"
assert len(local_param_indices) == len(global_params), (
"Mismatch between number of local and global parameters in parameter group"
)
for local_param_index, global_param in zip(
local_param_indices, global_params
):
@ -1268,9 +1269,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
dst_param_groups (list[dict]): parameter groups giving the
attribute settings to set.
"""
assert len(src_param_groups) == len(
dst_param_groups
), "Mismatch between number of source and destination parameter groups"
assert len(src_param_groups) == len(dst_param_groups), (
"Mismatch between number of source and destination parameter groups"
)
for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
# Sync all attributes except the parameters
for attr in filter(lambda x: x != "params", src_param_group.keys()):
@ -1479,9 +1480,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
The local optimizer is saved in ``self.optim``.
"""
assert (
self._optim_constructor is not None
), "The local optimizer class has not been set"
assert self._optim_constructor is not None, (
"The local optimizer class has not been set"
)
param_groups = self._partition_parameters()[self.rank]
# `overlap_with_ddp=True` requires a local functional optimizer
@ -1508,7 +1509,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
"error due to an empty parameter list",
self._optim_constructor,
)
self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef]
self.optim: Any = self._optim_constructor(
params, **self._optim_defaults
) # type: ignore[no-redef]
# Log information about the DDP and ZeRO bucketing
if dist.get_debug_level() != dist.DebugLevel.OFF:
@ -1531,7 +1534,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
else:
# NOTE: Passing `param_groups` into the local optimizer constructor
# bypasses the empty parameter list check
self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef]
self.optim: Optimizer = self._optim_constructor(
param_groups, **self._optim_defaults
) # type: ignore[no-redef]
# TODO: Manually add `self.param_groups` if using a functional
# optimizer; remove this if/when the functional optimizers support

View File

@ -123,12 +123,11 @@ def _insert_stage_symbolic_backward(
# getitem calls. If we have a target other than getitem in this
# (forward-only) code, there is a bug.
assert node.target == operator.getitem, (
"Found non-getitem call in forward pass. "
"Please report a bug to PiPPy"
"Found non-getitem call in forward pass. Please report a bug to PiPPy"
)
assert len(node.args) == 2, (
"Found malformed getitem call. Please report a bug to PiPPy"
)
assert (
len(node.args) == 2
), "Found malformed getitem call. Please report a bug to PiPPy"
indexed_value, node_idx = tuple(node.args)
# indexed_value is a collection that we are indexing into. It could
@ -249,8 +248,8 @@ class LossWrapper(torch.nn.Module):
targets value into the loss function, and get and return the loss value, which will
be backpropagated by PiPPy. The above class would then be instantiated like::
model = ... # instantiate the model
loss_fn = torch.nn.MSELoss() # for the sake of demonstration
model = ... # instantiate the model
loss_fn = torch.nn.MSELoss() # for the sake of demonstration
wrapper = MyModelWrapper(model, loss_fn)
pipe = Pipe.from_tracing(wrapper, ...)
@ -818,9 +817,9 @@ class Pipe(torch.nn.Module):
# Get submodule
callee = root.get_submodule(callee_name)
assert not hasattr(
callee, param_fqn
), f"Module {callee_name} already has a parameter named {param_fqn}"
assert not hasattr(callee, param_fqn), (
f"Module {callee_name} already has a parameter named {param_fqn}"
)
# Assign the parameter to the submodule
if is_buffer:
@ -979,7 +978,7 @@ class Pipe(torch.nn.Module):
else:
logger.debug("Pipeline is in inference mode, backward pass not generated")
logger.debug("Full pipe model:\n" f"{split}") # noqa: G004
logger.debug(f"Full pipe model:\n{split}") # noqa: G004
return Pipe(
split,
@ -1184,7 +1183,7 @@ def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
except AttributeError as e:
raise AttributeError(
f"Specified target {qualname} referenced "
f'nonexistent module {".".join(atoms[: i + 1])}'
f"nonexistent module {'.'.join(atoms[: i + 1])}"
) from e
mod_to_wrap = getattr(predecessor_module, atoms[-1])
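
The two Pipe hunks above show how f-strings are normalized: a plain literal glued to an f-string is folded into a single f-string, and nested quotes are flipped so the outer quotes can stay double. A sketch with invented values; both spellings build the same string:

atoms = ["layers", "0", "mlp"]
i = 1

# Old: single outer quotes so the "." literal inside can use double quotes,
# preceded by a plain literal implicitly concatenated to the f-string.
old_msg = "referenced " f'nonexistent module {".".join(atoms[: i + 1])}'

# New: one double-quoted f-string; the inner literal switches to single quotes.
new_msg = f"referenced nonexistent module {'.'.join(atoms[: i + 1])}"

assert old_msg == new_msg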

View File

@ -306,17 +306,17 @@ def stage_backward(
if isinstance(output_val, torch.Tensor):
if not output_val.requires_grad and output_val.grad_fn is None:
return
assert isinstance(
grad_val, (torch.Tensor, type(None))
), f"Expected Tensor or None gradient but got {type(grad_val)}"
assert isinstance(grad_val, (torch.Tensor, type(None))), (
f"Expected Tensor or None gradient but got {type(grad_val)}"
)
stage_output_tensors.append(output_val)
output_grad_tensors.append(grad_val)
elif isinstance(output_val, (tuple, list)):
if grad_val is None:
return
assert isinstance(
grad_val, (tuple, list)
), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
assert isinstance(grad_val, (tuple, list)), (
f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
)
assert len(output_val) == len(grad_val)
for ov, gv in zip(output_val, grad_val):
extract_tensors_with_grads(
@ -350,7 +350,8 @@ def stage_backward(
)
torch.autograd.backward(
stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type]
stage_output_tensors,
grad_tensors=output_grad_tensors, # type: ignore[arg-type]
)
# Extract gradients wrt the input values

View File

@ -140,9 +140,9 @@ def _shard_dict_of_args(
real_num_chunks = num_chunks
first_tensor = True
assert len(args_dict) == len(
args_chunk_spec
), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
assert len(args_dict) == len(args_chunk_spec), (
f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
)
for arg_key, arg in args_dict.items():
flat, spec = tree_flatten(arg)

View File

@ -706,7 +706,9 @@ class Schedule1F1B(PipelineScheduleSingle):
recv_work.wait()
# Compute
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
output = self._stage.forward_one_chunk(
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
) # type: ignore[index]
# Clear previous chunk's forward sends (hopefully they have well
# finished, otherwise, we are heavily communication bound, in which
@ -762,7 +764,9 @@ class Schedule1F1B(PipelineScheduleSingle):
fuse_work.wait()
# Now do the fwd
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
output = self._stage.forward_one_chunk(
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
) # type: ignore[index]
# Compute loss
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
@ -992,9 +996,9 @@ def _add_send_recv(
progress = False
# go in order of ranks even if dict keys aren't ordered
for rank in sorted(compute_actions):
assert (
len(compute_actions[rank]) > 0
), f"{rank=}, {len(compute_actions[rank])=}"
assert len(compute_actions[rank]) > 0, (
f"{rank=}, {len(compute_actions[rank])=}"
)
action = compute_actions[rank][0]
if not _ready_to_schedule(action, prev_actions[rank]):
@ -1026,9 +1030,9 @@ def _validate_schedule(
num_stages: int,
num_microbatches: int,
) -> dict[int, int]:
assert (
len(actions) == pp_group_size
), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
assert len(actions) == pp_group_size, (
f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
)
for rank in range(pp_group_size):
assert rank in actions, f"Schedule is missing actions for rank {rank}"
@ -1048,36 +1052,36 @@ def _validate_schedule(
for action in actions[rank]:
if action is None:
continue
assert isinstance(
action, _Action
), f"Got an invalid action: {action}, expected instance of _Action"
assert isinstance(action, _Action), (
f"Got an invalid action: {action}, expected instance of _Action"
)
s_id = action.stage_index
ctype = action.computation_type
mb_id = action.microbatch_index
if ctype == F:
stage_actions[s_id][F].add(mb_id)
elif ctype == B:
assert (
mb_id in stage_actions[s_id][F]
), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
assert mb_id in stage_actions[s_id][F], (
f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
)
stage_actions[s_id][B].add(mb_id)
elif ctype == I:
assert (
mb_id in stage_actions[s_id][F]
), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
assert mb_id in stage_actions[s_id][F], (
f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
)
stage_actions[s_id][I].add(mb_id)
elif ctype == W:
assert (
mb_id in stage_actions[s_id][I]
), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input"
assert mb_id in stage_actions[s_id][I], (
f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input"
)
stage_actions[s_id][W].add(mb_id)
if s_id not in stage_index_to_rank_mapping:
stage_index_to_rank_mapping[s_id] = rank
else:
existing_rank = stage_index_to_rank_mapping[s_id]
assert (
rank == existing_rank
), f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
assert rank == existing_rank, (
f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
)
for s_id in stage_actions:
f_mb = len(stage_actions[s_id][F])
@ -1085,14 +1089,14 @@ def _validate_schedule(
i_mb = len(stage_actions[s_id][I])
w_mb = len(stage_actions[s_id][W])
assert (
f_mb == num_microbatches
), f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
assert f_mb == num_microbatches, (
f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
)
assert (
b_mb + (i_mb + w_mb) // 2 == num_microbatches
), f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, (
f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
but got B={b_mb}, I={i_mb}, W={w_mb}"
)
return stage_index_to_rank_mapping
@ -1289,9 +1293,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
computation_type = action.computation_type
mb_index = action.microbatch_index
stage_index = action.stage_index
assert (
mb_index is not None
), "All currently supported action types require valid microbatch_index"
assert mb_index is not None, (
"All currently supported action types require valid microbatch_index"
)
if computation_type == _ComputationType.FORWARD:
# perform forward computation
stage = stage_index_to_stage[stage_index]
@ -1362,9 +1366,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
computation_type = prev_rank_action.computation_type
mb_index = prev_rank_action.microbatch_index
stage_index = prev_rank_action.stage_index
assert (
mb_index is not None
), "All currently supported action types require valid microbatch_index"
assert mb_index is not None, (
"All currently supported action types require valid microbatch_index"
)
# Only handle sends for the forward from a previous rank
if computation_type == _ComputationType.FORWARD:
# If not the last stage, then receive fwd activations
@ -1393,9 +1397,9 @@ class PipelineScheduleMulti(_PipelineSchedule):
computation_type = next_rank_action.computation_type
mb_index = next_rank_action.microbatch_index
stage_index = next_rank_action.stage_index
assert (
mb_index is not None
), "All currently supported action types require valid microbatch_index"
assert mb_index is not None, (
"All currently supported action types require valid microbatch_index"
)
# Only handle receives for the backwards from a next rank
if computation_type in (FORWARD, BACKWARD_WEIGHT):
# Next rank doing forward or weight update has no influence for the current rank backward recv
@ -1503,9 +1507,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
"""Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
# TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
# that it does not exist if it was created from a compute_comms schedule.
assert (
self.pipeline_order_with_comms is not None
), "Must initialize compute_comms schedule before dump_csv"
assert self.pipeline_order_with_comms is not None, (
"Must initialize compute_comms schedule before dump_csv"
)
with open(filename, "w", newline="") as csvfile:
writer = csv.writer(csvfile)
for rank in self.pipeline_order_with_comms:
@ -1541,9 +1545,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
stage.stage_index: stage for stage in self._stages
}
assert (
self.pipeline_order_with_comms is not None
), "Must call _load_actions() before calling _step_microbatches()"
assert self.pipeline_order_with_comms is not None, (
"Must call _load_actions() before calling _step_microbatches()"
)
# recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
bwd_recv_ops: dict[tuple[int, int], Work] = {}
@ -1562,9 +1566,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
unshard_ops[stage_idx].wait()
del unshard_ops[stage_idx]
unsharded_stages.add(stage_idx)
assert (
stage_idx in unsharded_stages
), f"Attempted to compute on sharded {stage_idx=}"
assert stage_idx in unsharded_stages, (
f"Attempted to compute on sharded {stage_idx=}"
)
# count either full_backward or backward_weight together, to determine when to sync DP grads
backward_counter: Counter[int] = Counter()
@ -1606,7 +1610,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
assert (
stage_idx,
mb_index,
) not in fwd_recv_ops, f"Recv twice for {stage_idx=} {mb_index=} without executing forward"
) not in fwd_recv_ops, (
f"Recv twice for {stage_idx=} {mb_index=} without executing forward"
)
fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
stage.get_fwd_recv_ops(mb_index)
)
@ -1614,7 +1620,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
assert (
stage_idx,
mb_index,
) not in bwd_recv_ops, f"Recv twice for {stage_idx=} {mb_index=} without executing backward"
) not in bwd_recv_ops, (
f"Recv twice for {stage_idx=} {mb_index=} without executing backward"
)
bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
stage.get_bwd_recv_ops(mb_index)
)
@ -1627,12 +1635,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator]
elif comp_type == RESHARD:
if stage_uses_fsdp:
assert (
stage_idx in unsharded_stages
), f"Resharding {stage_idx=} without unsharding"
assert (
stage_idx not in unshard_ops
), f"Resharding {stage_idx=} before finishing unshard"
assert stage_idx in unsharded_stages, (
f"Resharding {stage_idx=} without unsharding"
)
assert stage_idx not in unshard_ops, (
f"Resharding {stage_idx=} before finishing unshard"
)
stage.submod.reshard() # type: ignore[operator]
elif comp_type == FORWARD:
if stage_uses_fsdp:
@ -1739,7 +1747,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
)
# TODO(whc) what is the best practice for printing a multiline log?
# logger will split it into multiple log lines, but this makes it hard to read (too wide)
print(_format_pipeline_order(self.pipeline_order_with_comms, error_step_number=time_step)) # type: ignore[arg-type]
print(
_format_pipeline_order(
self.pipeline_order_with_comms, # type: ignore[arg-type]
error_step_number=time_step,
)
)
raise e
# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them

View File

@ -243,16 +243,16 @@ class _PipelineStageBase(ABC):
configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
which could show up as hangs, silent corruption, or other errors.
"""
assert (
self._outputs_meta is None
), "Attempting to reconfigure output_meta, which is not supported"
assert self._outputs_meta is None, (
"Attempting to reconfigure output_meta, which is not supported"
)
self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment]
def get_outputs_meta(self) -> tuple[torch.Tensor, ...]:
"""Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
assert (
self._outputs_meta is not None
), "Attempted to get_outputs_meta() without configuring output meta"
assert self._outputs_meta is not None, (
"Attempted to get_outputs_meta() without configuring output meta"
)
return self._outputs_meta
def _create_grad_send_info(
@ -358,12 +358,12 @@ class _PipelineStageBase(ABC):
prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs)
for info, tensor in zip(recv_infos, prev_stage_outputs):
assert isinstance(
tensor, torch.Tensor
), f"expected tensor values as outputs from prev stage, got {type(tensor)}"
assert isinstance(
info, _RecvInfo
), "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo"
assert isinstance(tensor, torch.Tensor), (
f"expected tensor values as outputs from prev stage, got {type(tensor)}"
)
assert isinstance(info, _RecvInfo), (
"set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo"
)
# We don't need to do a data copy here, since we can directly pass the activation tensor reference from
# one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve
@ -376,9 +376,9 @@ class _PipelineStageBase(ABC):
"""
Returns the input grad tensors for this stage, which correspond to the stage inputs during forward.
"""
assert (
self.has_backward
), "can't steal_bwd_input if this stage doesn't have backward"
assert self.has_backward, (
"can't steal_bwd_input if this stage doesn't have backward"
)
assert not self.is_first, "can't get bwd output if this stage is first"
self._check_chunk_id(mb_index)
@ -391,22 +391,22 @@ class _PipelineStageBase(ABC):
Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv.
Does not detach or set '_requires_grad'.
"""
assert isinstance(
next_stage_bwd_outputs, tuple
), f"Expected tuple, got {type(next_stage_bwd_outputs)}"
assert isinstance(next_stage_bwd_outputs, tuple), (
f"Expected tuple, got {type(next_stage_bwd_outputs)}"
)
assert (
self.has_backward
), "can't set bwd input if this stage doesn't have backward"
assert self.has_backward, (
"can't set bwd input if this stage doesn't have backward"
)
assert not self.is_last, "can't set bwd input if this stage is last"
recv_infos = self.grad_recv_info[mb_index]
for info, tensor in zip(recv_infos, next_stage_bwd_outputs):
assert isinstance(
tensor, torch.Tensor
), f"expected tensor values as outputs from prev stage, got {type(tensor)}"
assert isinstance(
info, _RecvInfo
), f"Expected a recv info, got {type(info)}"
assert isinstance(tensor, torch.Tensor), (
f"expected tensor values as outputs from prev stage, got {type(tensor)}"
)
assert isinstance(info, _RecvInfo), (
f"Expected a recv info, got {type(info)}"
)
info.buffer = tensor
def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
@ -1053,9 +1053,9 @@ class _PipelineStage(_PipelineStageBase):
# If the input is a getitem, we need to go deeper
arg_node = arg_node.args[0]
assert (
arg_node.op == "call_module"
), f"Expecting call_module, got {arg_node.op}"
assert arg_node.op == "call_module", (
f"Expecting call_module, got {arg_node.op}"
)
src_stage = self.get_stage_index_of_submod(arg_node.name)
# Create a receive buffer for this placeholder
@ -1081,7 +1081,8 @@ class _PipelineStage(_PipelineStageBase):
args_recv_info: list[InputInfo] = []
# Filter out placeholder nodes from `self.submod` (a GraphModule)
placeholders = filter( # type: ignore[var-annotated]
lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr]
lambda node: node.op == "placeholder", # type: ignore[arg-type]
self.submod.graph.nodes, # type: ignore[arg-type,union-attr]
)
# `placeholders` are nodes internal to submod.
# `self.node.args` are dependency nodes in the outer graph.
@ -1300,9 +1301,9 @@ class PipelineStage(_PipelineStageBase):
raise RuntimeError(
"Failed to perform pipeline shape inference- are your inputs on the same device as your module?"
) from e
assert (
output_args is not None
), "If passing input_args, also pass output_args to override shape inference"
assert output_args is not None, (
"If passing input_args, also pass output_args to override shape inference"
)
self._configure_outputs_meta(
(output_args,) if isinstance(output_args, torch.Tensor) else output_args
)
@ -1346,9 +1347,9 @@ class PipelineStage(_PipelineStageBase):
)
args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args)
else:
assert (
len(args) == 0
), "Can't supply input args for shape inference on non-first stage"
assert len(args) == 0, (
"Can't supply input args for shape inference on non-first stage"
)
objects = [None]
logger.debug(
"Shape inference: stage %s receiving from stage %s",

View File

@ -80,9 +80,9 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
world_size = world_size_opt
if rank != -1 or world_size != -1 or world_size_opt is None:
query_dict = _query_to_dict(result.query)
assert (
"rank" not in query_dict and "world_size" not in query_dict
), f"The url: {url} has node-specific arguments(rank, world_size) already."
assert "rank" not in query_dict and "world_size" not in query_dict, (
f"The url: {url} has node-specific arguments(rank, world_size) already."
)
if rank != -1:
query_dict["rank"] = str(rank)
if world_size != -1 or world_size_opt is None:

View File

@ -137,13 +137,13 @@ def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
with _all_gather_dict_lock:
if not worker_names:
worker_names = _ALL_WORKER_NAMES
assert (
worker_name in worker_names
), f"{worker_name} is not expected by leader."
assert worker_name in worker_names, (
f"{worker_name} is not expected by leader."
)
states = _all_gather_sequence_id_to_states[sequence_id]
assert (
worker_name not in states.gathered_objects
), f"{worker_name} reported intent sequence id {sequence_id} twice. "
assert worker_name not in states.gathered_objects, (
f"{worker_name} reported intent sequence id {sequence_id} twice. "
)
states.gathered_objects[worker_name] = obj
if worker_names == set(states.gathered_objects.keys()):
states.proceed_signal.set()
@ -153,9 +153,9 @@ def _broadcast_to_followers(sequence_id, objects_map):
with _all_gather_dict_lock:
states = _all_gather_sequence_id_to_states[sequence_id]
assert (
not states.proceed_signal.is_set()
), f"Termination signal sequence id {sequence_id} got set twice."
assert not states.proceed_signal.is_set(), (
f"Termination signal sequence id {sequence_id} got set twice."
)
states.gathered_objects = objects_map
states.proceed_signal.set()
@ -202,9 +202,9 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
function blocks until all workers have received the gathered results.
"""
if not worker_names:
assert (
_ALL_WORKER_NAMES is not None
), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
assert _ALL_WORKER_NAMES is not None, (
"`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
)
worker_names = _ALL_WORKER_NAMES
leader_name = min(worker_names)
@ -930,8 +930,7 @@ def _get_should_profile():
ActiveProfilerType = torch._C._profiler.ActiveProfilerType
return (
torch.autograd._profiler_enabled()
and torch._C._autograd._profiler_type()
== ActiveProfilerType.LEGACY # type: ignore[attr-defined]
and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
)

View File

@ -23,7 +23,7 @@ def _to_device(device: DeviceType) -> torch.device:
def _to_device_map(
device_map: dict[DeviceType, DeviceType]
device_map: dict[DeviceType, DeviceType],
) -> dict[torch.device, torch.device]:
full_device_map: dict[torch.device, torch.device] = {}
reverse_map: dict[torch.device, torch.device] = {}
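
The `_to_device_map` hunk above (and the later `register_prop_rule` / `as_list` hunks) shows a smaller change: when a signature with a single long parameter stays split across lines, a trailing comma is now added after that parameter. A sketch with an invented helper:

from collections.abc import Callable

# Old: no trailing comma after the single wrapped parameter.
def apply_old(
    fn: Callable[[int], int]
) -> int:
    return fn(21)

# New: a trailing comma is added when the parameter keeps its own line.
def apply_new(
    fn: Callable[[int], int],
) -> int:
    return fn(21)

assert apply_old(lambda x: x * 2) == apply_new(lambda x: x * 2) == 42
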
@ -127,7 +127,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
>>> options = TensorPipeRpcBackendOptions(
>>> num_worker_threads=8,
>>> device_maps={"worker1": {0: 1}}
>>> # maps worker0's cuda:0 to worker1's cuda:1
>>> # maps worker0's cuda:0 to worker1's cuda:1
>>> )
>>> options.set_device_map("worker1", {1: 2})
>>> # maps worker0's cuda:1 to worker1's cuda:2

View File

@ -63,10 +63,14 @@ class _server_process_global_profile(profile):
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> x, y = torch.tensor(1), torch.tensor(2)
>>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
>>> outer_profile_rref = rpc.remote(
... dst_worker_name, rpc._server_process_global_profile
... )
>>> outer_profile_rref.rpc_sync().__enter__()
>>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
>>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
>>> inner_profile_rref = rpc.remote(
... dst_worker_name, rpc._server_process_global_profile
... )
>>> inner_profile_rref.rpc_sync().__enter__()
>>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
>>> inner_profile_rref.rpc_sync().__exit__(None, None, None)

View File

@ -289,9 +289,9 @@ Important Notices
::
>>> # xdoctest: +SKIP("stub")
>>> import torch.distributed as dist
>>> dist.init_process_group(backend="gloo|nccl")
>>> # xdoctest: +SKIP("stub")
>>> import torch.distributed as dist
>>> dist.init_process_group(backend="gloo|nccl")
3. In your training program, you can either use regular distributed functions
or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
@ -302,9 +302,9 @@ Important Notices
::
local_rank = int(os.environ["LOCAL_RANK"])
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[local_rank],
output_device=local_rank)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[local_rank], output_device=local_rank
)
Please ensure that ``device_ids`` argument is set to be the only GPU device id
that your code will be operating on. This is generally the local rank of the
@ -331,17 +331,18 @@ utility
::
def main():
load_checkpoint(checkpoint_path)
initialize()
train()
def main():
load_checkpoint(checkpoint_path)
initialize()
train()
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_checkpoint(checkpoint_path)
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_checkpoint(checkpoint_path)
9. (Recommended) On worker errors, this tool will summarize the details of the error
(e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
@ -353,17 +354,19 @@ utility
::
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.multiprocessing.errors import record
@record
def main():
# do train
pass
if __name__ == "__main__":
main()
@record
def main():
# do train
pass
if __name__ == "__main__":
main()
""" # noqa: E501
import os
import sys
import uuid

View File

@ -297,9 +297,9 @@ class DTensor(torch.Tensor):
@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
assert (
flatten_spec is not None
), "Expecting spec to be not None from `__tensor_flatten__` return value!"
assert flatten_spec is not None, (
"Expecting spec to be not None from `__tensor_flatten__` return value!"
)
local_tensor = inner_tensors["_local_tensor"]
spec, requires_grad = flatten_spec
unflatten_tensor_meta = TensorMeta(
@ -694,9 +694,7 @@ def distribute_tensor(
xla_distribute_tensor,
)
return xla_distribute_tensor(
tensor, device_mesh, placements
) # type:ignore[return-value]
return xla_distribute_tensor(tensor, device_mesh, placements) # type:ignore[return-value]
except ImportError as e:
msg = "To use DTensor API with xla, you must install the torch_xla package!"
raise ImportError(msg) from e
@ -930,7 +928,9 @@ def distribute_module(
FutureWarning,
stacklevel=2,
)
module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg]
module.register_forward_pre_hook(
lambda _, inputs: input_fn(inputs, device_mesh) # type: ignore[call-arg]
)
elif num_args == 3:
# input_fn takes in module, inputs, device mesh
module.register_forward_pre_hook(
@ -990,9 +990,9 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))
# check device_mesh against placements
assert device_mesh.ndim == len(
placements
), "mesh dimension does not match the length of placements"
assert device_mesh.ndim == len(placements), (
"mesh dimension does not match the length of placements"
)
assert kwargs["layout"] == torch.strided, "layout value not supported!"
torch_stride = torch._prims_common.make_contiguous_strides_for(size)

View File

@ -75,7 +75,8 @@ def found_inf_reduce_handler(
) -> None:
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
local_tensor_args = pytree.tree_unflatten(
cast(list[object], op_info.local_args), op_info.args_tree_spec # type: ignore[arg-type]
cast(list[object], op_info.local_args),
op_info.args_tree_spec, # type: ignore[arg-type]
)
local_tensor_args = cast(tuple[object, ...], local_tensor_args)
op_call(*local_tensor_args, **op_info.local_kwargs)
@ -200,8 +201,9 @@ class OpDispatcher:
# did not already construct one
random._rng_tracker = random.OffsetBasedRNGTracker(mesh)
first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
torch.Tensor, local_tensor_args[0]
first_arg, first_local_arg = (
cast(dtensor.DTensor, args[0]),
cast(torch.Tensor, local_tensor_args[0]),
)
rng_context = (
random._rng_tracker._distribute_region(first_arg._spec)
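
The dispatcher hunk above shows one more wrapping change: a tuple assignment whose right-hand side is too long now parenthesizes the whole tuple instead of breaking inside the last call. A sketch with invented values:

pair = {"a": 1, "b": 2}

# Old wrapping: the line breaks inside the trailing call's argument list.
first_key, first_value = str(next(iter(pair))), int(
    pair["a"]
)

# New wrapping: the right-hand side tuple itself is parenthesized.
first_key, first_value = (
    str(next(iter(pair))),
    int(pair["a"]),
)

assert (first_key, first_value) == ("a", 1)
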
@ -422,18 +424,18 @@ class OpDispatcher:
def wrap(res: object, spec: OutputSpecType) -> object:
if isinstance(res, torch.Tensor):
if spec is not None:
assert isinstance(
spec, DTensorSpec
), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
assert isinstance(spec, DTensorSpec), (
f"output spec does not match with output! Expected DTensorSpec, got {spec}."
)
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
else:
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
assert res.ndim == 0, "output tensor should be scalar!"
return res
elif isinstance(res, (list, tuple)):
assert spec is not None and isinstance(
spec, (list, tuple)
), f"output spec does not match with output! Expected list/tuple, got {spec}."
assert spec is not None and isinstance(spec, (list, tuple)), (
f"output spec does not match with output! Expected list/tuple, got {spec}."
)
res_list = []
for e, s in zip(res, spec):
res_list.append(OpDispatcher.wrap(e, s))

View File

@ -152,9 +152,9 @@ class OpStrategy(StrategyType):
if isinstance(output_spec, DTensorSpec):
return output_spec.mesh.shape
else:
assert isinstance(
output_spec, tuple
), "found no DTensorSpec in the OpStrategy!"
assert isinstance(output_spec, tuple), (
"found no DTensorSpec in the OpStrategy!"
)
assert output_spec[0] is not None
return output_spec[0].mesh.shape

View File

@ -63,9 +63,9 @@ class EinsumDims:
if is_batch_dim:
batch_dims.append(dim_char)
else:
assert (
len(input_dims) == 2
), "free dimension only supported for two inputs!"
assert len(input_dims) == 2, (
"free dimension only supported for two inputs!"
)
lhs, rhs = input_dims
if dim_char in lhs:
lhs_out_only_dims.append(dim_char)

View File

@ -89,9 +89,9 @@ class _MaskPartial(Partial):
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert (
self.offset_shape is not None
), "offset_shape needs to be set for _MaskPartial"
assert self.offset_shape is not None, (
"offset_shape needs to be set for _MaskPartial"
)
local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
self.offset_shape[self.offset_dim],
num_chunks,

View File

@ -994,9 +994,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
)
output_specs_list.append(weight_out_spec if output_mask[1] else None)
else:
assert (
output_mask[1] is False
), "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
assert output_mask[1] is False, (
"output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
)
output_specs_list.append(None)
# arg: bias
@ -1020,9 +1020,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
)
output_specs_list.append(bias_out_spec if output_mask[2] else None)
else:
assert (
output_mask[2] is False
), "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
assert output_mask[2] is False, (
"output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
)
output_specs_list.append(None)
out_tuple_strategy.strategies.append(

View File

@ -155,9 +155,9 @@ def _scaled_mm_like_strategy(
assert isinstance(scale_mat2_strategy, OpStrategy)
# TODO: add support for these later
assert bias_strategy is None, "_scaled_mm on DTensors doesn't support bias"
assert (
scale_result_strategy is None
), "_scaled_mm on DTensors doesn't support scale_result"
assert scale_result_strategy is None, (
"_scaled_mm on DTensors doesn't support scale_result"
)
# generate all possible strategies for mm
mm_strategy = gen_einsum_strategies(mm_equation, mesh)
# filter out invalid strategies and associate costs

View File

@ -445,9 +445,9 @@ def pointwise_strategy(
followed_strategy = op_schema.args_schema[max_shards_strategy_index]
assert isinstance(
followed_strategy, OpStrategy
), f"no strategy to follow for {op_schema}!"
assert isinstance(followed_strategy, OpStrategy), (
f"no strategy to follow for {op_schema}!"
)
return common_pointwise_strategy(
mesh, op_schema.args_schema, followed_strategy, linearity
)

View File

@ -254,9 +254,9 @@ def dim_movedim(
def dim_repeat(ndim: int, sizes: Shape) -> DimMap:
sizes = normalize_sizes(sizes)
assert (
len(sizes) >= ndim
), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
assert len(sizes) >= ndim, (
f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
)
pad = len(sizes) - ndim
return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple(
Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:])
@ -275,9 +275,9 @@ def infer_size(total_size: int, sizes: Shape) -> Shape:
if infers:
size = -size
missing_size = total_size // size
assert (
total_size % size == 0
), f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
assert total_size % size == 0, (
f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
)
return tuple(s if s != -1 else missing_size for s in sizes)
assert size == total_size, f"sizes do not match {total_size} vs {size}"
return sizes
@ -538,9 +538,9 @@ def propagate_shape_and_sharding(
for size, shard in zip(mesh_sizes, input_src_placements):
if isinstance(shard, Shard) and shard.dim == in_dim:
submesh_size *= size
assert (
out_size % submesh_size == 0
), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
assert out_size % submesh_size == 0, (
f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
)
# we will only shard our first component of the split
return in_dim if cmd.split_id == 0 else None

View File

@ -45,7 +45,7 @@ def register_prop_rule(
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def wrapper(
impl: Callable[[OpSchema], OutputSharding]
impl: Callable[[OpSchema], OutputSharding],
) -> Callable[[OpSchema], OutputSharding]:
overloads = op if isinstance(op, list) else [op]
for overload in overloads:
@ -102,7 +102,7 @@ def register_op_strategy(
def as_list(
x: Union[list[object], object]
x: Union[list[object], object],
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,

View File

@ -231,9 +231,9 @@ def redistribute_local_tensor(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert (
current.is_shard()
), f"Current placement should be shard but found {current}"
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(

View File

@ -487,9 +487,9 @@ class ShardingPropagator:
strategy_costs: list[float] = []
for strtg in strategy.strategies:
assert (
strtg.redistribute_cost is not None
), "must set redistribute cost each strategy!"
assert strtg.redistribute_cost is not None, (
"must set redistribute cost each strategy!"
)
redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
strategy_costs.append(redistribute_cost)

View File

@ -73,9 +73,9 @@ def compute_local_shape_and_global_offset(
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(
local_shape
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
)
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
@ -141,16 +141,15 @@ def compute_local_shape_and_global_offset(
if isinstance(placement, _StridedShard):
strided_part_seen[shard_dim] = True
shard_idx_stride_by_mesh_dim[shard_dim][
idx
] = num_shards_by_tensor_dim[shard_dim] // (
placement.split_factor * mesh_dim_size
shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
num_shards_by_tensor_dim[shard_dim]
// (placement.split_factor * mesh_dim_size)
)
else:
num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
shard_idx_stride_by_mesh_dim[shard_dim][
idx
] = num_shards_by_tensor_dim[shard_dim]
shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
num_shards_by_tensor_dim[shard_dim]
)
shard_idx = [
sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
@ -205,9 +204,9 @@ def compute_global_tensor_info(
)
shard_dim = shard_placement.dim
assert (
shard_dim < tensor.ndim
), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
assert shard_dim < tensor.ndim, (
f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
)
local_dim_size = tensor_shape[shard_dim]
tensor_shape[shard_dim] = local_dim_size * mesh_dim_size

View File

@ -283,9 +283,9 @@ class CommDebugMode(TorchDispatchMode):
"module_type" in self.advanced_module_tracker.module_helper_dict[fqn]
and include_module_data
):
json_dict[
"module_type"
] = self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]
json_dict["module_type"] = (
self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]
)
if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
for (
@ -659,9 +659,9 @@ class CommDebugMode(TorchDispatchMode):
operation_dict["is_bw"] = self.advanced_module_tracker.is_bw
# tracks if the operation is part of activation checkpointing
operation_dict[
"is_activation_checkpointing"
] = self.advanced_module_tracker.activation_checkpointing
operation_dict["is_activation_checkpointing"] = (
self.advanced_module_tracker.activation_checkpointing
)
if any(t == DTensor for t in types):
for ele in args:
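
The two CommDebugMode hunks above show another recurring rewrite: for an assignment that exceeds the line limit, Black broke inside the subscripted target, whereas ruff format keeps the target on one line and parenthesizes the right-hand side. A hedged sketch with invented names; the statement here is short enough to fit on one line, and the parentheses are kept only to mirror the shape produced for the longer real statements:

module_helper_dict = {
    "layer1.linear": {"module_type": "torch.nn.modules.linear.Linear"}
}
json_dict: dict[str, str] = {}
fqn = "layer1.linear"

# Black-era layout, shown as a comment: the subscripted target itself was split.
# json_dict[
#     "module_type"
# ] = module_helper_dict[fqn]["module_type"]

# ruff format layout: the target stays intact and the value is parenthesized.
json_dict["module_type"] = (
    module_helper_dict[fqn]["module_type"]
)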

View File

@ -108,9 +108,9 @@ def _compute_local_shape_and_global_offset(
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(
local_shape
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
)
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,

View File

@ -2,6 +2,7 @@
To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.py -e MLP_operation_tracing
"""
import argparse
import os
from typing import Callable, Union

View File

@ -6,6 +6,7 @@ with intermediate activations sharded across multiple GPUs via DTensor
To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
"""
import os
import time

View File

@ -3,6 +3,7 @@
The following example demonstrates how to represent torchrec's embedding
sharding with the DTensor API.
"""
import argparse
import os
from functools import cached_property
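
Each of the three example scripts above gains exactly one line with nothing removed; judging from the hunk headers, this is the blank line that ruff format places between a module docstring and the code that follows it, which is invisible in this rendering. A small sketch of the resulting layout, using a placeholder docstring:

"""Placeholder module docstring, standing in for the example scripts' docstrings."""

import argparse  # ruff format separates the module docstring above from this import with a blank line
import os

parser = argparse.ArgumentParser(description=os.path.basename(__file__))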

View File

@ -253,22 +253,18 @@ class _AttentionOp(Protocol):
key: torch.Tensor,
value: torch.Tensor,
**kwargs: object,
) -> tuple[torch.Tensor, ...]:
...
) -> tuple[torch.Tensor, ...]: ...
class _RingRotater(ABC):
@abstractmethod
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None:
...
def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ...
@abstractmethod
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None:
...
def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ...
@abstractmethod
def next_buffer(self) -> torch.Tensor:
...
def next_buffer(self) -> torch.Tensor: ...
class _AllToAllRotater(_RingRotater):
@ -1097,15 +1093,13 @@ class _LoadBalancer(ABC):
@abstractmethod
def shard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
...
) -> torch.Tensor: ...
@classmethod
@abstractmethod
def unshard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
...
) -> torch.Tensor: ...
class _SequentialSharder(_LoadBalancer):
@ -1147,9 +1141,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
def shard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
assert (
cls.ROUND_ROBIN_CYCLE == 2
), "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
assert cls.ROUND_ROBIN_CYCLE == 2, (
"The current implementation only works if ROUND_ROBIN_CYCLE is 2."
)
cp_world_size = mesh.size()
cp_rank = mesh.get_local_rank()
assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0
@ -1163,9 +1157,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
def unshard(
cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
assert (
cls.ROUND_ROBIN_CYCLE == 2
), "The current implementation only works if ROUND_ROBIN_CYCLE is 2."
assert cls.ROUND_ROBIN_CYCLE == 2, (
"The current implementation only works if ROUND_ROBIN_CYCLE is 2."
)
buffer = buffer.contiguous()
cp_world_size = mesh.size()
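
The Protocol and ABC stubs near the top of this file show one more ruff format rule: a method body consisting only of an ellipsis (...) is folded onto the signature line rather than kept on a line of its own. A self-contained sketch with hypothetical names:

from abc import ABC, abstractmethod

import torch


class _HypotheticalRotater(ABC):
    # Black kept the ellipsis body on its own line:
    #     @abstractmethod
    #     def next_buffer(self) -> torch.Tensor:
    #         ...
    # ruff format folds it onto the signature line:
    @abstractmethod
    def next_buffer(self) -> torch.Tensor: ...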

View File

@ -113,9 +113,15 @@ def local_map(
>>> device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
>>> W_dt = distribute_tensor(
... W, device_mesh, (col_wise)
... ) # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(
... X, device_mesh, (row_wise)
... ) # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(
... device_mesh, W_dt, X_dt
... ) # apply local_mm_allreduce_forward to DTensors
.. note:: This API is currently experimental and subject to change
"""
@ -151,9 +157,9 @@ def local_map(
)
if in_placements is not None:
spec = in_placements[idx]
assert (
spec is not None
), f"DTensor input {arg} expects placements but received {spec}!"
assert spec is not None, (
f"DTensor input {arg} expects placements but received {spec}!"
)
if not isinstance(spec, tuple):
spec = tuple(spec)
@ -208,17 +214,17 @@ def local_map(
)
for out, spec in zip(flat_out, out_placements_tuple):
if isinstance(out, torch.Tensor):
assert not isinstance(
out, DTensor
), f"torch.Tensor output expected but received {type(out)}: {out}"
assert not isinstance(out, DTensor), (
f"torch.Tensor output expected but received {type(out)}: {out}"
)
flat_dist_out.append(
DTensor.from_local(out, device_mesh, spec, run_check=False)
)
else:
assert (
spec is None
), f"Non-tensor output {out} expects None placements but received {spec}!"
assert spec is None, (
f"Non-tensor output {out} expects None placements but received {spec}!"
)
flat_dist_out.append(out)
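
In the local_map docstring hunk above, the doctest lines themselves are reflowed, with long calls wrapped onto "..." continuation lines; that only happens when docstring code formatting is enabled, which this reflow suggests is the case in the repo's ruff configuration. A sketch of the same doctest layout around an invented helper:

def scaled_sum(x: float, y: float, scale: float = 1.0) -> float:
    """Return scale * (x + y).

    Example, laid out the way ruff format leaves an over-long doctest call:
        >>> result = scaled_sum(
        ...     1.0, 2.0, scale=0.5
        ... )  # hypothetical call, wrapped with doctest continuation lines
        >>> round(result, 3)
        1.5
    """
    return scale * (x + y)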

View File

@ -188,9 +188,14 @@ def _mark_sharding(
"""
Mark the sharding strategy for each node in the graph module.
"""
placement_strategies: dict[
Node, PlacementStrategy
] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)
placement_strategies: dict[Node, PlacementStrategy] = (
_mark_tensor_parallel_shardings(
gm,
graph_signature,
mesh,
parameter_placements,
)
)
for node in gm.graph.nodes:
if node.op == "placeholder":
@ -202,9 +207,9 @@ def _mark_sharding(
elif node.op == "call_function":
if node.target == operator.getitem:
input_nodes = node.all_input_nodes
assert (
len(input_nodes) == 1
), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
assert len(input_nodes) == 1, (
f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
)
arg_strategy = placement_strategies[input_nodes[0]]
placement_strategies[node] = _create_placement_strategy(
node,

View File

@ -328,7 +328,9 @@ class DTensorExtensions(FSDPExtensions):
self.device_handle = device_handle
# we have to use the dynamo disable this way to disable dynamo as the decorator way would
# trigger build failure with torch deploy...
self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform) # type: ignore[method-assign]
self.post_unflatten_transform = torch._dynamo.disable( # type: ignore[method-assign]
self.post_unflatten_transform
)
def pre_flatten_transform(
self,

View File

@ -64,9 +64,7 @@ def input_reshard(
return module
def _pack_hook_tp(
mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor
) -> Any: # noqa: D401
def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401
"""Hook function called after FWD to shard input."""
if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
@ -84,9 +82,7 @@ def _pack_hook_tp(
return x
def _unpack_hook_tp(
mesh: DeviceMesh, input_reshard_dim: int, x: Any
) -> torch.Tensor: # noqa: D401
def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401
"""Hook function called before activation recomputing in BWD to restore input."""
if (
isinstance(x, DTensor)

View File

@ -38,8 +38,7 @@ class ParallelStyle(ABC):
src_data_rank: Optional[int] = 0
@abstractmethod
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
...
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ...
class ColwiseParallel(ParallelStyle):
@ -467,19 +466,21 @@ class PrepareModuleInput(ParallelStyle):
)
self.use_local_output = use_local_output
if self.input_layouts is not None:
assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
assert len(self.input_layouts) == len(
self.desired_input_layouts
), "input_layouts and desired_input_layouts should have same length!"
assert self.desired_input_layouts is not None, (
"desired module inputs should not be None!"
)
assert len(self.input_layouts) == len(self.desired_input_layouts), (
"input_layouts and desired_input_layouts should have same length!"
)
self.with_kwargs = input_kwarg_layouts is not None
self.input_kwarg_layouts = input_kwarg_layouts or {}
self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
if self.with_kwargs:
assert len(self.input_kwarg_layouts) == len(
self.desired_input_kwarg_layouts
), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
), (
"input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
)
def _prepare_input_arg(
self,
@ -494,9 +495,9 @@ class PrepareModuleInput(ParallelStyle):
# assert inp.placements[0] == input_layout
dt_inp = input
else:
assert isinstance(
input, torch.Tensor
), "expecting input to be a torch.Tensor!"
assert isinstance(input, torch.Tensor), (
"expecting input to be a torch.Tensor!"
)
dt_inp = DTensor.from_local(
input, mesh, (input_layout,), run_check=False
)
@ -517,9 +518,9 @@ class PrepareModuleInput(ParallelStyle):
if len(inputs) != len(self.input_layouts):
raise ValueError("module inputs and input_layouts should have same length!")
assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
assert self.desired_input_layouts is not None, (
"desired module inputs should not be None!"
)
for inp, input_layout, desired_layout in zip(
inputs, self.input_layouts, self.desired_input_layouts
):
@ -551,7 +552,9 @@ class PrepareModuleInput(ParallelStyle):
with_kwargs=True,
) # type: ignore[misc]
else:
module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg]
module.register_forward_pre_hook(
lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
) # type: ignore[misc, call-arg]
return module
@ -611,9 +614,9 @@ class PrepareModuleOutput(ParallelStyle):
else desired_output_layouts
)
self.use_local_output = use_local_output
assert len(self.output_layouts) == len(
self.desired_output_layouts
), "output_layouts and desired_output_layouts should have same length!"
assert len(self.output_layouts) == len(self.desired_output_layouts), (
"output_layouts and desired_output_layouts should have same length!"
)
def _prepare_out_fn(self, outputs, device_mesh):
prepared_outputs = []
@ -649,5 +652,7 @@ class PrepareModuleOutput(ParallelStyle):
return tuple(prepared_outputs)
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)) # type: ignore[misc, call-arg]
module.register_forward_hook(
lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
) # type: ignore[misc, call-arg]
return module
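
The last two hunks in this file wrap over-long register_forward_pre_hook / register_forward_hook calls at the call parentheses, and the trailing type-ignore comment ends up after the closing parenthesis of the statement. A runnable sketch in which the hook body is a stand-in, not the real _prepare_input_fn:

import torch
import torch.nn as nn


def _inspect_inputs(inputs: tuple) -> None:
    # Stand-in for the real pre-hook logic; it only records how many inputs arrived.
    print(f"pre-hook saw {len(inputs)} input(s)")


module = nn.Linear(8, 8)

# The over-long call is split at its parentheses; the suppression comment that used
# to trail the one-line form now sits after the closing parenthesis of the statement.
module.register_forward_pre_hook(
    lambda _mod, inputs: _inspect_inputs(inputs)
)  # type: ignore[misc, call-arg]

module(torch.randn(2, 8))  # runs the forward pass once, which triggers the hook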

View File

@ -83,9 +83,9 @@ class Shard(Placement):
few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
This is because collectives usually require equal size tensor inputs
"""
assert (
self.dim <= tensor.ndim
), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
assert self.dim <= tensor.ndim, (
f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
)
# chunk tensor over dimension `dim` into n slices
tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
@ -468,9 +468,9 @@ class _StridedShard(Shard):
"""
TODO: currently _StridedShard does not support padding
"""
assert (
self.dim <= tensor.ndim
), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
assert self.dim <= tensor.ndim, (
f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
)
total_split = num_chunks * self.split_factor
assert tensor.size(self.dim) % total_split == 0, (

Some files were not shown because too many files have changed in this diff Show More