Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-12 06:44:55 +08:00)

Compare commits: 5 commits (documentat… ... ciflow/ind…)

| Author | SHA1 | Date |
|---|---|---|
| | 8d4e1101c2 | |
| | 9074ffeb40 | |
| | 08537982e6 | |
| | 816779ad01 | |
| | 69f8d844ba | |
@@ -73,7 +73,7 @@ if is_available():
            _DistributedPdb().set_trace()
        """

-         def interaction(self, *args, **kwargs):
+         def interaction(self, *args, **kwargs) -> None:
            _stdin = sys.stdin
            try:
                sys.stdin = open("/dev/stdin")

@@ -83,7 +83,7 @@ if is_available():

    _breakpoint_cache: dict[int, typing.Any] = {}

-     def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600):
+     def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600) -> None:
        """
        Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
        done with the breakpoint before continuing.
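For context, the `breakpoint` helper annotated above pauses a single rank while the others wait. A minimal usage sketch, assuming a script launched with `torchrun` and an initialized default process group:

```python
# Sketch only: drop rank 0 into pdb; the remaining ranks block until it resumes.
import torch.distributed as dist

def debug_step() -> None:
    if dist.is_available() and dist.is_initialized():
        dist.breakpoint(rank=0)  # other ranks wait on a barrier until rank 0 continues
```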
@@ -16,7 +16,7 @@ _T = TypeVar("_T", covariant=True)
_P = ParamSpec("_P")


- def generate_state_key(string="__composable_api_state_key"):
+ def generate_state_key(string: str = "__composable_api_state_key") -> str:
    return f"{string}_{str(uuid.uuid4())}"


@@ -183,7 +183,9 @@ def contract(
                f"Outputs: {num_new_modules} modules"
            )

-         def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
+         def check_fqn(
+             orig_fqns: list[str], new_fqns: list[str], check_key: str
+         ) -> None:
            if orig_fqns == new_fqns:
                return


@@ -27,8 +27,7 @@ except Exception:
            stacklevel=2,
        )

-     def is_torchdynamo_compiling(): # type: ignore[misc]
-         return False
+     def is_torchdynamo_compiling() -> bool: # type: ignore[misc]
+         return False


@@ -316,7 +316,7 @@ class LocalIntNode:
            return ConstantIntNode(next(iter(local_ints.values())))
        return super().__new__(cls)

-     def __init__(self, local_ints: dict[int, int]):
+     def __init__(self, local_ints: dict[int, int]) -> None:
        self._local_ints = local_ints

    def maybe_as_int(self) -> Optional[int]:

@@ -590,7 +590,7 @@ class LocalTensor(torch.Tensor):

    @torch._disable_dynamo
    @mark_subclass_constructor_exportable_experimental # type: ignore[misc]
-     def __init__(self, *args: Any, **kwargs: Any):
+     def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def __deepcopy__(self, memo: dict[Any, Any] | None) -> "LocalTensor":

@@ -765,7 +765,7 @@ class LocalTensorMode(TorchDispatchMode):
    """

    # What ranks this local tensor mode is operating over
-     def __init__(self, ranks: Union[int, frozenset[int]]):
+     def __init__(self, ranks: Union[int, frozenset[int]]) -> None:
        if isinstance(ranks, int):
            # assume is world size
            self.ranks = frozenset(range(ranks))

@@ -1087,7 +1087,7 @@ class LocalRunnerMode:

    def __init__(
        self, ranks: frozenset[int] | int, concurrency: int, fn: Callable[[int], None]
-     ):
+     ) -> None:
        if isinstance(ranks, int):
            ranks = frozenset(range(ranks))
        self._ranks = ranks

@@ -80,7 +80,7 @@ def shard_parameter(
    sharding_spec: ShardingSpec,
    src_rank=0,
    process_group=None,
- ):
+ ) -> None:
    """
    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
    module, it shards that parameter according to the provided

@@ -223,7 +223,9 @@ def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
    return module


- def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None):
+ def shard_module(
+     module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None
+ ) -> None:
    """
    Shards a given module according to the provided sharding `plan`. This method
    first shards all the parameters according to the given sharding `plan`. Then if

@@ -5,7 +5,7 @@ import torch
from torch.utils import _pytree as pytree


- def _basic_validation(op, args=(), kwargs=None):
+ def _basic_validation(op, args=(), kwargs=None) -> None:
    """
    Common validation across all ops go in here.
    """

@@ -17,7 +17,7 @@ def _basic_validation(op, args=(), kwargs=None):
    # Validate types
    has_distributed_tensor = False

-     def is_distributed_tensor(e):
+     def is_distributed_tensor(e) -> None:
        nonlocal has_distributed_tensor
        if isinstance(e, ShardedTensor):
            has_distributed_tensor = True

@@ -34,7 +34,7 @@ def _basic_validation(op, args=(), kwargs=None):
    # Validate all distributed tensors use the same PG.
    cur_pg: Optional[torch.distributed.ProcessGroup] = None

-     def validate_pg(e):
+     def validate_pg(e) -> None:
        nonlocal cur_pg
        if isinstance(e, ShardedTensor):
            if cur_pg is not None and e._process_group is not cur_pg:

@@ -48,7 +48,7 @@ def _basic_validation(op, args=(), kwargs=None):
    pytree.tree_map_(validate_pg, kwargs)


- def _register_default_op(op, decorator):
+ def _register_default_op(op, decorator) -> None:
    @decorator(op)
    def tensor_default_op(types, args=(), kwargs=None, pg=None):
        """

@@ -34,7 +34,7 @@ class ShardMetadata:
        shard_offsets: list[int],
        shard_sizes: list[int],
        placement: Optional[Union[str, _remote_device]] = None,
-     ):
+     ) -> None:
        self.shard_offsets = shard_offsets
        self.shard_sizes = shard_sizes
        if isinstance(placement, str):

@@ -11,7 +11,7 @@ and PartialTensor.
"""


- def _register_op(op, func, op_table):
+ def _register_op(op, func, op_table) -> None:
    """
    Performs basic validation and registers the provided op in the given
    op_table.

@@ -409,7 +409,7 @@ def init_from_local_shards(
    )


- def state_dict_hook(module, destination, prefix, local_metadata):
+ def state_dict_hook(module, destination, prefix, local_metadata) -> None:
    """
    Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
    registered to the Module using

@@ -432,7 +432,7 @@ def pre_load_state_dict_hook(
    missing_keys,
    unexpected_keys,
    error_msgs,
- ):
+ ) -> None:
    """
    Pre-load state dict hook to add ShardedTensor to the module.
    """

@@ -62,7 +62,7 @@ def _sharded_op_common(op, early_stop_func, extra_check):

def _register_sharded_op_on_local_shards(
    op, early_stop_func=None, extra_check=None, customized_func=None
- ):
+ ) -> None:
    """
    Handles ``__torch_function__`` dispatch for ops which are performed on
    each shard of the sharded tensor such as elementwise op like

@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
+
import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.sharded_tensor import _sharded_op_impl

@@ -135,7 +136,7 @@ tensor_like_creation_op_map = {


# tensor ops that behave the same as the default tensor
- def register_tensor_creation_op(op):
+ def register_tensor_creation_op(op) -> None:
    @_sharded_op_impl(op)
    def tensor_creation_op(types, args=(), kwargs=None, pg=None):
        """

@@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
+ from typing import Literal
+
import torch
from torch.distributed._shard.sharded_tensor import _sharded_op_impl

@@ -8,5 +10,7 @@ from torch.distributed._shard.sharded_tensor import _sharded_op_impl
# the future behavior of overwriting the existing tensor
# instead of doing in-place change using `.data = `.
@_sharded_op_impl(torch._has_compatible_shallow_copy_type)
- def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None):
+ def tensor_has_compatible_shallow_copy_type(
+     types, args=(), kwargs=None, pg=None
+ ) -> Literal[False]:
    return False

@@ -61,7 +61,7 @@ def st_is_meta(types, args=(), kwargs=None, pg=None):
    return args[0].local_tensor().is_meta


- def sharded_type_as_check(*args, **kwargs):
+ def sharded_type_as_check(*args, **kwargs) -> None:
    """
    Perform extra checks for the sharded_type_as op such as the input needs to
    be either a Tensor or ShardedTensor.

@@ -91,7 +91,7 @@ def _flatten_tensor_size(size) -> torch.Size:
    return torch.Size(dims)


- def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
+ def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True) -> None:
    if is_local:
        assert isinstance(ranks, int)
    if expected != actual:

@@ -5,7 +5,9 @@ from typing import Optional
from torch.distributed._shard.metadata import ShardMetadata


- def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata):
+ def _check_shard_metadata_pair_overlap(
+     shard1: ShardMetadata, shard2: ShardMetadata
+ ) -> bool:
    """
    Checks if two shards overlap.
    """

@@ -70,7 +72,7 @@ def _find_1d_overlapping_shards(
    return None


- def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]):
+ def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]) -> None:
    """
    Ensures none of the shards overlap with each other.


@@ -18,7 +18,7 @@ from torch.distributed.nn.functional import (
)


- def _chunk_sharding_spec_check(spec, op):
+ def _chunk_sharding_spec_check(spec, op) -> None:
    """
    For the given op implementation check if the sharding spec is ChunkShardingSpec.
    """

@@ -30,7 +30,7 @@ def _chunk_sharding_spec_check(spec, op):

def _register_sharded_op_on_local_tensor(
    op, early_stop_func=None, extra_check=None, customized_func=None
- ):
+ ) -> None:
    """
    Handles ``__torch_function__`` dispatch for ops which are performed on
    the single local tensor of the sharded tensor such as op like

@@ -128,7 +128,7 @@ def sharded_embedding(types, args, kwargs, pg):
    )


- def _validate_embedding_param(args, kwargs):
+ def _validate_embedding_param(args, kwargs) -> None:
    """
    Validate input params of sharded embedding op.


@@ -150,7 +150,7 @@ def sharded_embedding_bag(types, args, kwargs, pg):
    )


- def _validate_embedding_bag_param(args, kwargs):
+ def _validate_embedding_bag_param(args, kwargs) -> None:
    """
    Validate input params of sharded embeddingBag op.


@@ -113,7 +113,7 @@ class _ModMemStats:
    values as the memory consumed in bytes.
    """

-     def __init__(self, mod_fqn: str):
+     def __init__(self, mod_fqn: str) -> None:
        self.mod_fqn = mod_fqn
        self.parameter_mem: int
        self.buffer_mem: int

@@ -157,7 +157,7 @@ class MemoryTracker:
    def show_traces(self, path: str = "") -> None:
        import matplotlib.pyplot as plt

-         def _plot_figure(x, y_values, labels):
+         def _plot_figure(x, y_values, labels) -> None:
            min_val = min(chain.from_iterable(y_values)) * 0.999
            max_val = max(chain.from_iterable(y_values)) * 1.001
            plt.figure()

@@ -55,7 +55,7 @@ class ModTracker:
    A Set containing the fqn for each module currently running their forward
    """

-     def __init__(self):
+     def __init__(self) -> None:
        self.parents = {"Global"}
        self._active_module_cnt = {}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()

@@ -67,7 +67,7 @@ class ModTracker:
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

-     def _maybe_set_engine_callback(self):
+     def _maybe_set_engine_callback(self) -> None:
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

@@ -76,7 +76,7 @@ class ModTracker:
            torch.autograd.Variable._execution_engine.queue_callback(post_bw_callback)
            self._post_bw_callbacks_to_enqueue.clear()

-         def callback():
+         def callback() -> None:
            self.parents = {"Global"}
            self._has_callback = False

@@ -102,7 +102,7 @@ class ModTracker:
        post_fw_hook: Optional[Callable] = None,
        pre_bw_hook: Optional[Callable] = None,
        post_bw_hook: Optional[Callable] = None,
-     ):
+     ) -> None:
        """
        Registers user-specified hooks to be called before/after the forward/backward pass for each
        module tracked by the ``ModTracker``. One or more can be ``None``.

@@ -149,7 +149,7 @@ class ModTracker:
            post_bw_hook, self._user_post_bw_hook, "post_bw_hook"
        )

-     def clear_user_hooks(self):
+     def clear_user_hooks(self) -> None:
        """
        Clears the user specified hooks registered with ``register_user_hooks``
        """
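The `register_user_hooks` signature annotated above takes up to four optional callables. A hedged sketch of driving the tracker (the hook signatures and the `torch.distributed._tools.mod_tracker` import path are assumptions based on the hunks above, not part of this change set):

```python
# Sketch: print module names as ModTracker sees them enter/leave forward.
import torch
from torch.distributed._tools.mod_tracker import ModTracker  # assumed path

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
tracker = ModTracker()
tracker.register_user_hooks(
    pre_fw_hook=lambda mod, inp: print("enter", type(mod).__name__),
    post_fw_hook=lambda mod, inp, out: print("exit", type(mod).__name__),
)
with tracker:
    model(torch.randn(2, 4)).sum().backward()
```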
@@ -170,12 +170,14 @@ class ModTracker:
        return mod_name

    def _get_append_fn(self, w_mod, name, is_bw):
-         def fn(*args):
+         def fn(*args) -> None:
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents and not self.is_bw:

-                 def custom_formatwarning(msg, category, filename, lineno, line=None):
+                 def custom_formatwarning(
+                     msg, category, filename, lineno, line=None
+                 ) -> str:
                    return f"{filename}:{lineno}: {category.__name__}: {msg} \n"

                # pyrefly: ignore [bad-assignment]

@@ -197,7 +199,7 @@ class ModTracker:
        return fn

    def _get_pop_fn(self, w_mod, name, is_bw):
-         def fn(*args):
+         def fn(*args) -> None:
            if self._user_post_bw_hook is not None and is_bw:
                self._user_post_bw_hook(w_mod(), args)
            if name in self.parents:

@@ -213,7 +215,7 @@ class ModTracker:

        return fn

-     def _fw_pre_hook(self, mod, input):
+     def _fw_pre_hook(self, mod, input) -> None:
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        self._get_append_fn(w_mod, name, False)()

@@ -229,7 +231,7 @@ class ModTracker:
            self._get_pop_fn(w_mod, name, True)
        )

-     def _fw_post_hook(self, mod, input, output):
+     def _fw_post_hook(self, mod, input, output) -> None:
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        if self._user_post_fw_hook is not None:

@@ -21,7 +21,7 @@ class DefaultState:
        "gradient_postdivide_factor",
    ]

-     def __init__(self, process_group: dist.ProcessGroup):
+     def __init__(self, process_group: dist.ProcessGroup) -> None:
        if process_group is None:
            raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
        self.process_group = process_group

@@ -64,12 +64,12 @@ class LowPrecisionState(DefaultState):
        self,
        process_group,
        parameter_type=torch.float32,
-     ):
+     ) -> None:
        super().__init__(process_group)
        self.parameter_type = parameter_type


- def _decompress(state: LowPrecisionState, grad: torch.Tensor):
+ def _decompress(state: LowPrecisionState, grad: torch.Tensor) -> None:
    """
    Casts gradients back to full parameter precision so that further computation happens in full precision.
    """

@@ -92,7 +92,7 @@ def _decompress(state: LowPrecisionState, grad: torch.Tensor):
    orig_grad_data.record_stream(backend.current_stream()) # type: ignore[arg-type]


- def allreduce_hook(state: DefaultState, grad: torch.Tensor):
+ def allreduce_hook(state: DefaultState, grad: torch.Tensor) -> None:
    r"""
    Implement the FSDP communication hook for ``all_reduce`` algorithm and a necessary pre- and post-division of gradients.

@@ -111,7 +111,9 @@ def allreduce_hook(state: DefaultState, grad: torch.Tensor):
        grad.div_(state.gradient_postdivide_factor)


- def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
+ def reduce_scatter_hook(
+     state: DefaultState, grad: torch.Tensor, output: torch.Tensor
+ ) -> None:
    r"""
    Implement the FSDP communication hook for ``reduce_scatter`` algorithm.

@@ -137,7 +139,7 @@ def _low_precision_hook(
    state: LowPrecisionState,
    grad: torch.Tensor,
    output: Optional[torch.Tensor],
- ):
+ ) -> None:
    if grad.dtype != prec:
        grad.data = grad.data.to(prec)
    if output is not None:

@@ -65,7 +65,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer):
        f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
        self._opt_hook_state = _OptimizerHookState(f_optim, params)

-     def register_ddp(self, ddp_inst: DistributedDataParallel):
+     def register_ddp(self, ddp_inst: DistributedDataParallel) -> None:
        # NOTE: using a custom communication hook and fused optimizer is not
        # yet supported.
        ddp_inst.register_comm_hook( # type: ignore[operator]

@@ -106,7 +106,7 @@ def auto_quantize(func, qtype, quant_loss=None):
    """

    @functools.wraps(func)
-     def wrapper(*args, **kwargs):
+     def wrapper(*args, **kwargs) -> None:
        group = kwargs.get("group")
        async_op = kwargs.get("async_op", False)
        if async_op is True:

@@ -30,7 +30,7 @@ from . import (
__all__ = ["DDPCommHookType", "register_ddp_comm_hook"]


- def _ddp_comm_hook_wrapper(comm_hook, model, state):
+ def _ddp_comm_hook_wrapper(comm_hook, model, state) -> None:
    model.register_comm_hook(state, comm_hook)


@@ -40,7 +40,7 @@ def _powerSGD_comm_hook_wrapper(
    state,
    matrix_approximation_rank,
    start_powerSGD_iter=1_000,
- ):
+ ) -> None:
    """
    Wrap PowerSGD communication hook.

@@ -123,7 +123,7 @@ class DDPCommHookType(Enum):
    )


- def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None):
+ def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None) -> None:
    """
    Register ``ddp_comm_hooks`` to DDP model.
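`register_ddp_comm_hook` pairs one of the `DDPCommHookType` members with a DDP-wrapped model. A sketch under the assumption of an already-initialized process group and one CUDA device per rank on a single node:

```python
# Sketch: attach the built-in FP16 compression hook to a DDP model.
import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import (
    DDPCommHookType,
    register_ddp_comm_hook,
)
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(torch.nn.Linear(8, 8).cuda(), device_ids=[dist.get_rank()])
register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model)  # state defaults to None
```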
@@ -23,7 +23,7 @@ def _perform_local_step(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
    rank: int,
- ):
+ ) -> None:
    r"""
    Perform a local optimizer step using the gradients provided by ``bucket``.

@@ -67,7 +67,7 @@ def _perform_local_step(
def _broadcast_bucket(
    bucket_index: int,
    zero: ZeroRedundancyOptimizer,
- ):
+ ) -> None:
    r"""
    Broadcasts a bucket's parameters.

@@ -104,7 +104,7 @@ def _broadcast_bucket(
def _save_ddp_bucket_info(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
- ):
+ ) -> None:
    r"""
    Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``.

@@ -136,7 +136,7 @@ def _hook_with_zero_step_setup(
    ddp_ref: weakref.ReferenceType,
    zero: ZeroRedundancyOptimizer,
    bucket: dist.GradBucket,
- ):
+ ) -> None:
    r"""
    Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.


@@ -60,7 +60,7 @@ def _reducer_allreduce_and_upcast_hook(
            p.grad.data = p.grad.to(p.data.dtype)

        # enqueue a callback to wait for this stream at end of backward
-         def wait_for_stream_cb():
+         def wait_for_stream_cb() -> None:
            torch.accelerator.current_stream().wait_stream(stream)
            # Remove post-backward hooks since they are re-installed in next
            # iteration, similar to FSDP.

@@ -23,16 +23,16 @@ class _OptimizerHookState:

    __slots__ = ["functional_optimizer", "params_to_optimize"]

-     def __init__(self, functional_optim, params=None):
+     def __init__(self, functional_optim, params=None) -> None:
        self.functional_optimizer = functional_optim
        self._check_valid_functional_optim()
        self._set_params_to_optimize(params)

-     def _set_params_to_optimize(self, params):
+     def _set_params_to_optimize(self, params) -> None:
        if params is not None:
            self.params_to_optimize = set(params)

-     def _check_valid_functional_optim(self):
+     def _check_valid_functional_optim(self) -> None:
        if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME):
            raise ValueError(
                f"Class {type(self.functional_optimizer)} must implement method "

@@ -99,7 +99,7 @@ def _apply_optim_in_backward_hook(

        # enqueue a callback to wait for this optimizer stream at the end of
        # backward and set all DDP managed grads to None.
-         def wait_for_optim_stream_callback():
+         def wait_for_optim_stream_callback() -> None:
            torch.accelerator.current_stream().wait_stream(
                optim_stream_state.optim_stream
            )

@@ -38,7 +38,7 @@ class PostLocalSGDState:
        subgroup,
        start_localSGD_iter,
        post_local_gradient_allreduce=True,
-     ):
+     ) -> None:
        """Initialize state object with given parameters and log when localSGD start."""
        logger.info(
            "Local SGD will be started after %s iterations", start_localSGD_iter

@@ -55,7 +55,7 @@ class PostLocalSGDState:
        # Iteration/step in the training loop.
        self.iter = 0

-     def maybe_increase_iter(self, bucket):
+     def maybe_increase_iter(self, bucket) -> None:
        """Track iterations and trigger log message at start of local SGD."""
        # Since bucket 0 is the last bucket to allreduce in an iteration.
        # Only increase `iter` when bucket 0 is processed.

@@ -16,7 +16,7 @@ __all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"]
logger = logging.getLogger(__name__)


- def _orthogonalize(matrices, epsilon=0):
+ def _orthogonalize(matrices, epsilon=0) -> None:
    """
    Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices.

@@ -41,7 +41,7 @@ def _orthogonalize(matrices, epsilon=0):
    )


- def _orthogonalize_gram_schmidt(matrices, epsilon=0):
+ def _orthogonalize_gram_schmidt(matrices, epsilon=0) -> None:
    """
    Apply Gram-Schmidt procedure to orthogonalize a batch of matrices.

@@ -103,7 +103,7 @@ def _should_compress(
    )


- def _report_compression_stats(bucket, state):
+ def _report_compression_stats(bucket, state) -> None:
    """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state."""
    if bucket.is_last() and state.iter >= state.next_stats_report:
        stats = state.compression_stats()

@@ -188,7 +188,7 @@ class PowerSGDState:
        random_seed=0,
        compression_stats_logging_frequency=10_000,
        batch_tensors_with_same_shape: bool = False,
-     ):
+     ) -> None:
        logger.info(
            "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; "
            "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; "

@@ -302,7 +302,7 @@ class PowerSGDState:
        for slot, value in state.items():
            setattr(self, slot, value)

-     def maybe_increase_iter(self, bucket):
+     def maybe_increase_iter(self, bucket) -> None:
        """Track iterations and trigger log message at start of local SGD."""
        # Since bucket 0 is the last bucket to allreduce in an iteration.
        # Only increase `iter` when bucket 0 is processed.
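`PowerSGDState`, whose constructor is annotated above, is normally passed to `powerSGD_hook` via `register_comm_hook`. A sketch, assuming `ddp_model` is an existing `DistributedDataParallel` instance:

```python
# Sketch: compress gradients with rank-1 PowerSGD after 1,000 warm-up iterations.
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD

state = powerSGD.PowerSGDState(
    process_group=None,           # None means the default process group
    matrix_approximation_rank=1,
    start_powerSGD_iter=1_000,    # plain allreduce before compression kicks in
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)  # ddp_model assumed defined
```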
@@ -99,7 +99,9 @@ class HierarchicalModelAverager(averagers.ModelAverager):
    `HierarchicalModelAverager` is experimental and subject to change.
    """

-     def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
+     def __init__(
+         self, period_group_size_dict=None, warmup_steps=0, process_group=None
+     ) -> None:
        super().__init__(process_group)
        if not period_group_size_dict:
            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")

@@ -163,7 +165,7 @@ class HierarchicalModelAverager(averagers.ModelAverager):
        params: Union[
            Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
        ],
-     ):
+     ) -> None:
        """
        Averages parameters or parameter groups of an optimizer.


@@ -21,7 +21,7 @@ __all__ = [

def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
- ):
+ ) -> None:
    """
    Averages all the given parameters.

@@ -87,6 +87,6 @@ def average_parameters_or_parameter_groups(
        Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
    ],
    process_group: ProcessGroup,
- ):
+ ) -> None:
    """Averages parameters of a model or parameter groups of an optimizer."""
    average_parameters(iter(get_params_to_average(params)), process_group)

@@ -46,7 +46,7 @@ class HybridModel(torch.nn.Module):
    servers.
    """

-     def __init__(self, emb_rref_list, device):
+     def __init__(self, emb_rref_list, device) -> None:
        super().__init__()
        self.emb_rref_list = emb_rref_list
        fc1 = torch.nn.Linear(512, 256)

@@ -85,7 +85,7 @@ def _retrieve_embedding_parameters(emb_rref):
    return [RRef(p) for p in emb_rref.local_value().parameters()]


- def _print_header():
+ def _print_header() -> None:
    _print_cont("\n")
    _print_cont(" " * 10)
    for _ in [50, 75, 90, 95]:

@@ -93,7 +93,7 @@ def _print_header():
    _print_cont("\n")


- def _print_benchmark(prefix, nelem, measurements):
+ def _print_benchmark(prefix, nelem, measurements) -> None:
    measurements = sorted(measurements)
    _print_cont(f"{prefix:8s}:")
    for p in [50, 75, 90, 95]:

@@ -102,7 +102,7 @@ def _print_benchmark(prefix, nelem, measurements):
    _print_cont("\n")


- def _print_cont(msg):
+ def _print_cont(msg) -> None:
    print(msg, end="", flush=True)


@@ -201,7 +201,7 @@ def _run_trainer(emb_rref_list, rank):
    return rank, measurements, batch_size # type: ignore[possibly-undefined]


- def run_worker(rank, world_size):
+ def run_worker(rank, world_size) -> None:
    r"""
    Initialize RPC, calls the function, and shuts down RPC.
    """

@@ -34,7 +34,7 @@ class _CheckpointRequestIdentifier:
    checkpoint_id: Union[str, os.PathLike, None]
    uuid: str

-     def __init__(self, checkpoint_id: Union[str, os.PathLike, None]):
+     def __init__(self, checkpoint_id: Union[str, os.PathLike, None]) -> None:
        self.checkpoint_id = checkpoint_id
        self.uuid = str(uuid4())

@@ -58,7 +58,7 @@ class _ProcessGroupInitInfo:
    tcp_store_master_port: int
    use_prefix_store: bool

-     def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
+     def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        self.local_rank = dist.get_node_local_rank(fallback_rank=0)
        self.global_rank = dist.get_rank(process_group)
        self.world_size = dist.get_world_size(process_group)

@@ -103,7 +103,7 @@ class _AsyncCheckpointProcess:
    def __init__(
        self,
        pg_init_info: _ProcessGroupInitInfo,
-     ):
+     ) -> None:
        self.ctx = mp.get_context("spawn")
        self._process_pipe, child_end = self.ctx.Pipe()


@@ -39,7 +39,7 @@ class _Checkpointer:
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
-     ):
+     ) -> None:
        """Initializes the Checkpointer instance.

        Args:

@@ -105,7 +105,7 @@ class Barrier(abc.ABC):
    """

    @abc.abstractmethod
-     def __init__(self, **kwargs: dict[str, Any]):
+     def __init__(self, **kwargs: dict[str, Any]) -> None:
        """
        Initialize a barrier.

@@ -185,7 +185,7 @@ class TCPStoreBarrier(Barrier):
        tcpstore_port: int,
        master_address: str,
        timeout_secs: int,
-     ):
+     ) -> None:
        """
        Initialize a TCPStoreBarrier.


@@ -74,7 +74,7 @@ class CheckpointProcess:
        subprocess_init_args: tuple[Any, ...],
        checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
        checkpoint_writer_init_args: dict[str, Any],
-     ):
+     ) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._rank_info = rank_info
        self._config = config

@@ -32,7 +32,7 @@ class CheckpointReader:
    def __init__(
        self,
        rank_info: RankInfo,
-     ):
+     ) -> None:
        """
        Initialize a CheckpointReader.


@@ -74,7 +74,7 @@ class CheckpointWriter:
        rank_info: RankInfo,
        barrier: Optional[Barrier] = None,
        commit_hook: Optional[WriterHook] = None,
-     ):
+     ) -> None:
        """
        Initialize a CheckpointWriter.


@@ -110,7 +110,7 @@ class SyncCheckpointer(Checkpointer):
        self,
        writer: CheckpointWriter,
        reader: CheckpointReader,
-     ):
+     ) -> None:
        """
        Initialize a synchronous checkpointer.

@@ -225,7 +225,7 @@ class AsyncCheckpointer(Checkpointer):
        checkpoint_stager: CheckpointStager,
        checkpoint_process: CheckpointProcess,
        reader: CheckpointReader,
-     ):
+     ) -> None:
        """
        Initialize an asynchronous checkpointer.


@@ -138,7 +138,7 @@ class DefaultStager(CheckpointStager):
    def __init__(
        self,
        config: CheckpointStagerConfig = CheckpointStagerConfig(),
-     ):
+     ) -> None:
        self._config = config
        self._state_dict_stager = StateDictStager(
            pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory

@@ -32,7 +32,7 @@ class StateDictStager:
        pin_memory: bool = False,
        share_memory: bool = False,
        pin_memory_min_bytes: int = 5,
-     ):
+     ) -> None:
        if pin_memory and not torch.cuda.is_available():
            warnings.warn(
                "Ignoring pin_memory flag for checkpoint staging as pinning memory"

@@ -258,7 +258,7 @@ class StateDictStager:

        return y

-     def close(self):
+     def close(self) -> None:
        """
        Clean up all cached storages and release associated resources.

@@ -386,7 +386,7 @@ class StateDictStager:
        self._keep_alive(x, memo)  # Make sure x lives at least as long as d
        return y

-     def _keep_alive(self, x, memo):
+     def _keep_alive(self, x, memo) -> None:
        """Keeps a reference to the object x in the memo.

        Because we remember objects by their id, we have

@@ -22,7 +22,7 @@ def _is_wrapped_exception(obj: Any) -> bool:
class CheckpointException(BaseException):
    """Exception raised if failure was detected as part of a checkpoint load or save."""

-     def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]):
+     def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]) -> None:
        super().__init__(msg, failures)
        self._failures = failures


@@ -401,7 +401,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):

    """

-     def __init__(self, keys=None, *args, **kwargs):
+     def __init__(self, keys=None, *args, **kwargs) -> None:
        self.keys = keys
        super().__init__(*args, **kwargs)


@@ -66,7 +66,7 @@ def _init_model(rank, world_size):
    return model, optim


- def _print(msg):
+ def _print(msg) -> None:
    if dist.get_rank() == 0:
        print(msg)

@@ -80,7 +80,7 @@ def _input():
    return x, y


- def run(rank, world_size):
+ def run(rank, world_size) -> None:
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

@@ -55,7 +55,7 @@ def print_params(stage, model_1, model_2, optim_1, optim_2):
    )


- def run_fsdp_checkpoint_example(rank, world_size):
+ def run_fsdp_checkpoint_example(rank, world_size) -> None:
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

@@ -39,7 +39,7 @@ class Model(torch.nn.Module):
        return torch.rand(8, 8, device="cuda")


- def _make_stateful(model, optim):
+ def _make_stateful(model, optim) -> None:
    _patch_model_state_dict(model)
    _patch_optimizer_state_dict(model, optimizers=optim)

@@ -70,7 +70,7 @@ def _init_model(device, world_size):
    return model, optim


- def run(rank, world_size, device="cuda"):
+ def run(rank, world_size, device="cuda") -> None:
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

@@ -258,7 +258,7 @@ class _StorageWriterTransforms:
        # are appended.

        class NoCloseWriter(io.IOBase):
-             def __init__(self, raw: io.IOBase):
+             def __init__(self, raw: io.IOBase) -> None:
                self.raw = raw

            def writeable(self) -> bool:

@@ -267,7 +267,7 @@ class _StorageWriterTransforms:
            def write(self, b: Buffer) -> int:
                return self.raw.write(b)

-             def close(self):
+             def close(self) -> None:
                self.flush()
                self.raw.flush()
                # but not close.

@@ -208,7 +208,7 @@ class DynamicMetaLoadPlanner(DefaultLoadPlanner):
def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
- ):
+ ) -> None:
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch save file.

@@ -233,7 +233,7 @@ def dcp_to_torch_save(
def torch_save_to_dcp(
    torch_save_path: Union[str, os.PathLike],
    dcp_checkpoint_dir: Union[str, os.PathLike],
- ):
+ ) -> None:
    """
    Given the location of a torch save file, converts it into a DCP checkpoint.
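The two converters annotated above translate between DCP checkpoint directories and single `torch.save` files. A sketch with placeholder paths:

```python
# Sketch: round-trip a distributed checkpoint through a regular torch.save file.
import torch.distributed.checkpoint.format_utils as dcp_format

dcp_format.dcp_to_torch_save("ckpt_dcp/", "ckpt.pt")        # DCP dir -> torch.save file
dcp_format.torch_save_to_dcp("ckpt.pt", "ckpt_dcp_copy/")   # torch.save file -> DCP dir
```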
@@ -110,7 +110,7 @@ def _dcp_method_logger(
    return decorator


- def _init_logger(rank: int):
+ def _init_logger(rank: int) -> None:
    logger.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)

@@ -177,7 +177,7 @@ class MetadataIndex:
        fqn: str,
        offset: Optional[Sequence[int]] = None,
        index: Optional[int] = None,
-     ):
+     ) -> None:
        # We must use object.__setattr__ due to frozen=True
        object.__setattr__(self, "fqn", fqn)
        object.__setattr__(self, "index", index)

@@ -34,7 +34,7 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
        thread_count: int = 1,
        target_dtype: torch.dtype = torch.float32,
        block_size: int = 128,
-     ):
+     ) -> None:
        """
        Initialize the HuggingFace storage reader to load quantized checkpoints

@@ -66,7 +66,7 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):

        return metadata

-     def _load_quantization_metadata(self):
+     def _load_quantization_metadata(self) -> None:
        """Load quantization metadata from the checkpoint."""
        checkpoint_path = Path(self.path)
        # Load weight mapping from index file

@@ -77,7 +77,7 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
            weight_map = index_data.get("weight_map", {})
        self._build_weight_scale_mapping(weight_map)

-     def _build_weight_scale_mapping(self, weight_map: dict[str, str]):
+     def _build_weight_scale_mapping(self, weight_map: dict[str, str]) -> None:
        """Analyze and build weight-scale tensor pairs from weight mapping."""
        # Store the complete weight map for file location lookups
        self._weight_map = weight_map

@@ -174,7 +174,7 @@ class DefaultStager(AsyncStager):
    def __init__(
        self,
        config: StagingOptions = StagingOptions(),
-     ):
+     ) -> None:
        self._config = config
        self._state_dict_stager = StateDictStager(
            pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory

@@ -292,7 +292,7 @@ class BlockingAsyncStager(AsyncStager):
        self,
        cache_staged_state_dict: bool = False,
        type_check: bool = False,
-     ):
+     ) -> None:
        """
        Initializes the BlockingAsyncStager.

@@ -352,7 +352,7 @@ class _ReplicationStager(AsyncStager):
        timeout: timedelta = timedelta(minutes=30),
        device: torch.device = torch.device("cpu"),
        storage_dir: Optional[str] = None,
-     ):
+     ) -> None:
        self._pg = pg
        self._timeout = timeout
        # pyrefly: ignore [read-only]

@@ -1560,7 +1560,7 @@ def _patch_model_state_dict(
        options=options,
    )

-     def load_state_dict_call(state_dict: dict[str, Any]):
+     def load_state_dict_call(state_dict: dict[str, Any]) -> None:
        _load_state_dict_call(model_state_dict=state_dict)

    model.load_state_dict = load_state_dict_call

@@ -1619,7 +1619,7 @@ def _patch_optimizer_state_dict(
        options=options,
    )

-     def load_state_dict_call(state_dict: dict[str, Any]):
+     def load_state_dict_call(state_dict: dict[str, Any]) -> None:
        _load_state_dict_call(optim_state_dict=state_dict)

    _patched_state_dict.add(state_dict_call)

@@ -287,7 +287,7 @@ def _load_state_dict(
    central_plan = global_plan[0]

    @_dcp_method_logger(**ckpt_kwargs)
-     def read_data():
+     def read_data() -> None:
        if planner is None:
            raise AssertionError("planner is None")
        if central_plan is None:

@@ -344,7 +344,7 @@ def async_save(
        def callback(
            original_staging_future: Future[STATE_DICT_TYPE],
            return_staging_future: Future[None] = return_staging_future,
-         ):
+         ) -> None:
            try:
                original_staging_future.result()
                return_staging_future.set_result(None)

@@ -363,7 +363,7 @@ def async_save(
    else:

        @_dcp_method_logger(log_exceptions=True)
-         def maybe_synchronize_staging():
+         def maybe_synchronize_staging() -> None:
            if async_stager.should_synchronize_after_execute:
                async_stager.synchronize_staging()


@@ -87,7 +87,7 @@ class _DistWrapper:
        group: Optional[dist.ProcessGroup],
        use_dist: bool,
        coordinator_rank: int,
-     ):
+     ) -> None:
        self.group = group
        self.use_dist = use_dist
        self.coordinator_rank = coordinator_rank

@@ -383,7 +383,7 @@ def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]:


class _ReaderView(io.IOBase):
-     def __init__(self, base_stream: io.IOBase, offset: int, len: int):
+     def __init__(self, base_stream: io.IOBase, offset: int, len: int) -> None:
        super().__init__()
        self.offset = offset
        self.len = len

@@ -29,7 +29,7 @@ if not is_available():
    class _DeviceMeshStub:
        pass

-     def _init_device_mesh_stub():
+     def _init_device_mesh_stub() -> None:
        pass

    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined]

@@ -372,7 +372,7 @@ class Backend(str): # noqa: SLOT000
class BackendConfig:
    """Backend configuration class."""

-     def __init__(self, backend: Backend):
+     def __init__(self, backend: Backend) -> None:
        """Init."""
        self.device_backend_map: dict[str, Backend] = {}
        # pyrefly: ignore [bad-assignment]

@@ -442,7 +442,7 @@ class BackendConfig:

        logger.info("Using backend config: %s", self.device_backend_map)

-     def __repr__(self):
+     def __repr__(self) -> str:
        """Return all the device:backend pairs separated by commas."""
        return ",".join(
            f"{device}:{backend}" for device, backend in self.device_backend_map.items()

@@ -508,7 +508,7 @@ class P2POp:
        group: Optional[ProcessGroup] = None,
        tag: int = 0,
        group_peer: Optional[int] = None,
-     ):
+     ) -> None:
        """Init."""
        self.op = op
        self.tensor = tensor

@@ -534,7 +534,7 @@ class P2POp:

        return object.__new__(cls)

-     def __repr__(self):
+     def __repr__(self) -> str:
        my_group_rank = get_rank(self.group)
        op_name = self.op.__name__
        group_name = self.group.group_name if self.group else "default_pg"
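`P2POp`, whose `__repr__` is annotated above, is usually built in batches and handed to `batch_isend_irecv`. A sketch assuming an initialized group with at least two ranks:

```python
# Sketch: ring exchange where each rank sends to rank+1 and receives from rank-1.
import torch
import torch.distributed as dist

rank, world = dist.get_rank(), dist.get_world_size()
send_buf = torch.full((4,), float(rank))
recv_buf = torch.empty(4)
ops = [
    dist.P2POp(dist.isend, send_buf, (rank + 1) % world),
    dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world),
]
for req in dist.batch_isend_irecv(ops):
    req.wait()
```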
@@ -569,7 +569,7 @@ class _CollOp:
        dst_tensor: Optional[torch.Tensor] = None,
        redop: Optional[ReduceOp] = None,
        root: Optional[int] = None,
-     ):
+     ) -> None:
        self.op = op
        self.tensor = tensor
        self.dst_tensor = dst_tensor

@@ -734,7 +734,7 @@ class _WorldMeta(type):
        return _world.default_pg

    @WORLD.setter
-     def WORLD(cls, pg: Optional[ProcessGroup]):
+     def WORLD(cls, pg: Optional[ProcessGroup]) -> None:
        _world.default_pg = pg


@@ -1180,7 +1180,7 @@ def _canonicalize_group_rank(
    return group_rank


- def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str):
+ def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str) -> None:
    if group.rank() == rank:
        raise ValueError(
            f"Invalid {rank_type} rank: {rank_type} rank should not be the same as "

@@ -1832,7 +1832,7 @@ def init_process_group(
        old_hook = sys.excepthook
        excepthook_prefix = f"[rank{get_rank()}]"

-         def _distributed_excepthook(*args):
+         def _distributed_excepthook(*args) -> None:
            old_stderr = sys.stderr
            sys.stderr = buf = io.StringIO()
            try:

@@ -2216,7 +2216,7 @@ def _new_process_group_helper(
    return pg, prefix_store


- def destroy_process_group(group: Optional[ProcessGroup] = None):
+ def destroy_process_group(group: Optional[ProcessGroup] = None) -> None:
    """
    Destroy a given process group, and deinitialize the distributed package.
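`destroy_process_group` is the teardown half of the lifecycle started by `init_process_group`. A minimal sketch, assuming the `env://` rendezvous variables are set (as `torchrun` does):

```python
# Sketch: initialize, use, and tear down the default process group.
import torch.distributed as dist

dist.init_process_group(backend="gloo")
try:
    pass  # collectives would run here
finally:
    dist.destroy_process_group()
```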
@@ -2305,7 +2305,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
    _unregister_process_group(pg.group_name)


- def _abort_process_group(group: Optional[ProcessGroup] = None):
+ def _abort_process_group(group: Optional[ProcessGroup] = None) -> None:
    """
    Abort a given process group. If group.WORLD (i.e. `None`) is given, all
    process groups including the default one will be aborted.

@@ -2623,11 +2623,11 @@ class _CoalescingManager:
    def __init__(self) -> None:
        self.works: list[Work] = []

-     def append(self, work: Optional[Work] = None):
+     def append(self, work: Optional[Work] = None) -> None:
        if work:
            self.works.append(work)

-     def wait(self):
+     def wait(self) -> None:
        for work in self.works:
            work.wait()

@@ -3171,7 +3171,7 @@ def _tensor_to_object(tensor, tensor_size, group):


@_exception_logger
- def all_gather_object(object_list, obj, group=None):
+ def all_gather_object(object_list, obj, group=None) -> None:
    """
    Gathers picklable objects from the whole group into a list.
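`all_gather_object` fills a pre-sized list with one picklable object per rank. A sketch assuming an initialized default group:

```python
# Sketch: every rank contributes a small dict; each rank ends up with all of them.
import torch.distributed as dist

payload = {"rank": dist.get_rank()}
gathered: list = [None] * dist.get_world_size()
dist.all_gather_object(gathered, payload)  # gathered[i] is rank i's payload
```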
@@ -3272,7 +3272,7 @@ def gather_object(
    dst: Optional[int] = None,
    group: Optional[ProcessGroup] = None,
    group_dst: Optional[int] = None,
- ):
+ ) -> None:
    """
    Gathers picklable objects from the whole group in a single process.

@@ -3404,7 +3404,7 @@ def send_object_list(
    device: Optional[torch.device] = None,
    group_dst: Optional[int] = None,
    use_batch: bool = False,
- ):
+ ) -> None:
    """
    Sends picklable objects in ``object_list`` synchronously.

@@ -3663,7 +3663,7 @@ def broadcast_object_list(
    group: Optional[ProcessGroup] = None,
    device: Optional[torch.device] = None,
    group_src: Optional[int] = None,
- ):
+ ) -> None:
    """
    Broadcasts picklable objects in ``object_list`` to the whole group.

@@ -3795,7 +3795,7 @@ def scatter_object_list(
    src: Optional[int] = None,
    group: Optional[ProcessGroup] = None,
    group_src: Optional[int] = None,
- ):
+ ) -> None:
    """
    Scatters picklable objects in ``scatter_object_input_list`` to the whole group.

@@ -4250,7 +4250,7 @@ def all_gather_coalesced(
    # Otherwise, the backend has sync'ed at CPP level


- def _validate_output_list_for_rank(my_rank, dst, gather_list):
+ def _validate_output_list_for_rank(my_rank, dst, gather_list) -> None:
    if dst == my_rank:
        if not gather_list:
            raise ValueError(

@@ -162,7 +162,7 @@ class Worker:
        role_rank: int = -1,
        world_size: int = -1,
        role_world_size: int = -1,
-     ):
+     ) -> None:
        # unique identifier for this worker
        self.id: Any = None

@@ -188,14 +188,14 @@ class Worker:
        # the role world size may change between re-rendezvous.
        self.role_world_size: int = role_world_size

-     def __str__(self):
+     def __str__(self) -> str:
        return (
            f"local_rank={self.local_rank},global_rank={self.global_rank}"
            f",role_rank={self.role_rank},world_size={self.world_size}"
            f",role_world_size={self.role_world_size}"
        )

-     def __repr__(self):
+     def __repr__(self) -> str:
        return str(self)


@@ -271,7 +271,7 @@ class WorkerGroup:
        "master_port",
    ]

-     def __init__(self, spec: WorkerSpec):
+     def __init__(self, spec: WorkerSpec) -> None:
        self.spec = spec
        self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]

@@ -295,7 +295,7 @@ class _RoleInstanceInfo:

    __slots__ = ["role", "rank", "local_world_size"]

-     def __init__(self, role: str, rank: int, local_world_size: int):
+     def __init__(self, role: str, rank: int, local_world_size: int) -> None:
        r"""Initialize the agent class instance.

        Args:

@@ -449,7 +449,7 @@ class SimpleElasticAgent(ElasticAgent):
    such as one particular type of worker role.
    """

-     def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
+     def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300) -> None:
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None

@@ -845,7 +845,7 @@ class SimpleElasticAgent(ElasticAgent):
            f"torchelastic.worker.status.{state}", source=source, metadata=metadata
        )

-     def _record_metrics(self, group_results: RunResult):
+     def _record_metrics(self, group_results: RunResult) -> None:
        is_failed = group_results.is_failed()
        self._record_flakiness_metric(is_failed)
        spec = self._worker_group.spec

@@ -864,14 +864,14 @@ class SimpleElasticAgent(ElasticAgent):
            "run_failed_no_retries", is_failed and not restarts_happened
        )

-     def _record_metric_with_condition(self, metric_name, condition):
+     def _record_metric_with_condition(self, metric_name, condition) -> None:
        spec = self._worker_group.spec
        if condition:
            put_metric(f"workers.{spec.role}.{metric_name}", 1)
        else:
            put_metric(f"workers.{spec.role}.{metric_name}", 0)

-     def _record_flakiness_metric(self, is_failed: bool = False):
+     def _record_flakiness_metric(self, is_failed: bool = False) -> None:
        if is_failed:
            flakiness = 100.0
        else:

@@ -952,7 +952,7 @@ class SimpleElasticAgent(ElasticAgent):
            f"[{role}] Worker group in {state.name} state"
        )

-     def _exit_barrier(self):
+     def _exit_barrier(self) -> None:
        """
        Define a barrier that keeps the agent process alive until all workers finish.


@@ -153,7 +153,7 @@ class LocalElasticAgent(SimpleElasticAgent):
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_line_prefix_template: Optional[str] = None,
-     ):
+     ) -> None:
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        self._pcontext: Optional[PContext] = None

@@ -44,7 +44,7 @@ class Event:
    timestamp: int = 0
    metadata: dict[str, EventMetadataValue] = field(default_factory=dict)

-     def __str__(self):
+     def __str__(self) -> str:
        return self.serialize()

    @staticmethod

@@ -99,7 +99,7 @@ class RdzvEvent:
    local_id: Optional[int] = None
    error_trace: str = ""

-     def __str__(self):
+     def __str__(self) -> str:
        return self.serialize()

    @staticmethod

@@ -158,7 +158,7 @@ from .api import ( # noqa: F401
)


- def initialize_metrics(cfg: Optional[MetricsConfig] = None):
+ def initialize_metrics(cfg: Optional[MetricsConfig] = None) -> None:
    pass


@@ -37,7 +37,7 @@ MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value
class MetricsConfig:
    __slots__ = ["params"]

-     def __init__(self, params: Optional[dict[str, str]] = None):
+     def __init__(self, params: Optional[dict[str, str]] = None) -> None:
        self.params = params
        if self.params is None:
            self.params = {}

@@ -50,23 +50,23 @@ class MetricHandler(abc.ABC):


class ConsoleMetricHandler(MetricHandler):
-     def emit(self, metric_data: MetricData):
+     def emit(self, metric_data: MetricData) -> None:
        print(
            f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
        )


class NullMetricHandler(MetricHandler):
-     def emit(self, metric_data: MetricData):
+     def emit(self, metric_data: MetricData) -> None:
        pass


class MetricStream:
-     def __init__(self, group_name: str, handler: MetricHandler):
+     def __init__(self, group_name: str, handler: MetricHandler) -> None:
        self.group_name = group_name
        self.handler = handler

-     def add_value(self, metric_name: str, metric_value: int):
+     def add_value(self, metric_name: str, metric_value: int) -> None:
        self.handler.emit(
            MetricData(time.time(), self.group_name, metric_name, metric_value)
        )

@@ -77,7 +77,7 @@ _default_metrics_handler: MetricHandler = NullMetricHandler()


# pyre-fixme[9]: group has type `str`; used as `None`.
- def configure(handler: MetricHandler, group: Optional[str] = None):
+ def configure(handler: MetricHandler, group: Optional[str] = None) -> None:
    if group is None:
        global _default_metrics_handler
        # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used

@@ -188,7 +188,9 @@ def profile(group=None):
    return wrap


- def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
+ def put_metric(
+     metric_name: str, metric_value: int, metric_group: str = "torchelastic"
+ ) -> None:
    """
    Publish a metric data point.
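`put_metric` publishes a single data point through whichever handler was configured for the group. A hedged sketch (the exact import locations of `configure` and `ConsoleMetricHandler` are assumptions based on the hunks above):

```python
# Sketch: route torchelastic metrics to stdout and emit one counter.
from torch.distributed.elastic.metrics import configure, put_metric  # assumed exports
from torch.distributed.elastic.metrics.api import ConsoleMetricHandler

configure(ConsoleMetricHandler(), group="torchelastic")
put_metric("rendezvous.success", 1)  # defaults to the "torchelastic" group
```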
@@ -206,7 +208,7 @@ def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchel
    "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead",
    category=FutureWarning,
)
- def publish_metric(metric_group: str, metric_name: str, metric_value: int):
+ def publish_metric(metric_group: str, metric_name: str, metric_value: int) -> None:
    metric_stream = getStream(metric_group)
    metric_stream.add_value(metric_name, metric_value)


@@ -102,7 +102,7 @@ def _get_default_signal() -> signal.Signals:
    return signal.SIGTERM


- def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str):
+ def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str) -> None:
    actual_keys = set(d.keys())
    expected_keys = set(range(nprocs))

@@ -472,7 +472,7 @@ class PContext(abc.ABC):
        log_line_prefixes: Optional[dict[int, str]] = None,
        duplicate_stdout_filters: Optional[list[str]] = None,
        duplicate_stderr_filters: Optional[list[str]] = None,
-     ):
+     ) -> None:
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for

@@ -715,7 +715,7 @@ class MultiprocessContext(PContext):
        numa_options: Optional[NumaOptions] = None,
        duplicate_stdout_filters: Optional[list[str]] = None,
        duplicate_stderr_filters: Optional[list[str]] = None,
-     ):
+     ) -> None:
        super().__init__(
            name,
            entrypoint,

@@ -743,7 +743,7 @@ class MultiprocessContext(PContext):

        self._numa_options: Optional[NumaOptions] = numa_options

-     def _start(self):
+     def _start(self) -> None:
        if self._pc:
            raise ValueError(
                "The process context already initialized."

@@ -904,7 +904,7 @@ class SubprocessContext(PContext):
        numa_options: Optional[NumaOptions] = None,
        duplicate_stdout_filters: Optional[list[str]] = None,
        duplicate_stderr_filters: Optional[list[str]] = None,
-     ):
+     ) -> None:
        super().__init__(
            name,
            entrypoint,

@@ -922,7 +922,7 @@ class SubprocessContext(PContext):
        self.subprocess_handlers: dict[int, SubprocessHandler] = {}
        self._numa_options: Optional[NumaOptions] = numa_options

-     def _start(self):
+     def _start(self) -> None:
        if self.subprocess_handlers:
            raise ValueError(
                "The subprocess handlers already initialized. Most likely the start method got called twice."

@@ -940,7 +940,7 @@ class SubprocessContext(PContext):
            for local_rank in range(self.nprocs)
        }

-     def _capture_process_failures(self, done_local_ranks: set[int]):
+     def _capture_process_failures(self, done_local_ranks: set[int]) -> None:
        for local_rank in self._running_local_ranks:
            handler = self.subprocess_handlers[local_rank]
            exitcode = handler.proc.poll()

@@ -157,7 +157,7 @@ class ProcessFailure:
        timestamp = int(message["extraInfo"]["timestamp"])
        return (message, timestamp)

-     def _set_no_reply_file(self):
+     def _set_no_reply_file(self) -> None:
        self.error_file = _NOT_AVAILABLE
        self.error_file_data = _EMPTY_ERROR_DATA
        self.message = ""

@@ -237,7 +237,7 @@ class ChildFailedError(Exception):
    of trainer 1's error file to the scheduler's init process.
    """

-     def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]):
+     def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]) -> None:
        self.name = name
        self.failures = failures
        assert (
@ -92,7 +92,7 @@ class ErrorHandler:
|
||||
rootcause_error_file: str,
|
||||
rootcause_error: dict[str, Any],
|
||||
error_code: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
"""Modify the rootcause_error read from the file, to correctly set the exit code."""
|
||||
if "message" not in rootcause_error:
|
||||
logger.warning(
|
||||
@ -110,7 +110,7 @@ class ErrorHandler:
|
||||
else:
|
||||
rootcause_error["message"]["errorCode"] = error_code
|
||||
|
||||
def dump_error_file(self, rootcause_error_file: str, error_code: int = 0):
|
||||
def dump_error_file(self, rootcause_error_file: str, error_code: int = 0) -> None:
|
||||
"""Dump parent error file from child process's root cause error and error code."""
|
||||
with open(rootcause_error_file) as fp:
|
||||
rootcause_error = json.load(fp)
|
||||
@ -147,7 +147,7 @@ class ErrorHandler:
|
||||
rootcause_error_file,
|
||||
)
|
||||
|
||||
def _rm(self, my_error_file):
|
||||
def _rm(self, my_error_file) -> None:
|
||||
if os.path.isfile(my_error_file):
|
||||
# Log the contents of the original file.
|
||||
with open(my_error_file) as fp:
|
||||
|
||||
@ -87,7 +87,7 @@ def redirect(std: str, to_file: str):
|
||||
python_std = _python_std(std)
|
||||
std_fd = python_std.fileno()
|
||||
|
||||
def _redirect(dst):
|
||||
def _redirect(dst) -> None:
|
||||
libc.fflush(c_std)
|
||||
python_std.flush()
|
||||
os.dup2(dst.fileno(), std_fd)
|
||||
|
||||
@ -42,7 +42,7 @@ class SubprocessHandler:
|
||||
stderr: Optional[str],
|
||||
local_rank_id: int,
|
||||
numa_options: Optional[NumaOptions],
|
||||
):
|
||||
) -> None:
|
||||
self._stdout = open(stdout, "w") if stdout else None
|
||||
self._stderr = open(stderr, "w") if stderr else None
|
||||
# inherit parent environment vars
|
||||
|
||||
@ -32,7 +32,7 @@ def tail_logfile(
|
||||
finished: Event,
|
||||
interval_sec: float,
|
||||
log_line_filter: Optional[Callable[[str], bool]] = None,
|
||||
):
|
||||
) -> None:
|
||||
while not os.path.exists(file):
|
||||
if finished.is_set():
|
||||
return
|
||||
@ -102,7 +102,7 @@ class TailLog:
|
||||
log_line_prefixes: Optional[dict[int, str]] = None,
|
||||
interval_sec: float = 0.1,
|
||||
log_line_filter: Callable[[str], bool] = (lambda _: True),
|
||||
):
|
||||
) -> None:
|
||||
n = len(log_files)
|
||||
self._threadpool = None
|
||||
if n > 0:
|
||||
|
||||
@ -115,7 +115,7 @@ class RendezvousInfo:
|
||||
rank: int,
|
||||
world_size: int,
|
||||
bootstrap_store_info: RendezvousStoreInfo,
|
||||
):
|
||||
) -> None:
|
||||
self._store = store
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
@ -267,7 +267,7 @@ class RendezvousParameters:
|
||||
max_nodes: int,
|
||||
local_addr: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
if not backend:
|
||||
raise ValueError("The rendezvous backend name must be a non-empty string.")
|
||||
|
||||
|
||||
@ -183,7 +183,7 @@ class RendezvousTimeout:
|
||||
"""Get the keep-alive heartbeat timeout."""
|
||||
return self._heartbeat
|
||||
|
||||
def _set_timeouts(self, **timeouts: Optional[timedelta]):
|
||||
def _set_timeouts(self, **timeouts: Optional[timedelta]) -> None:
|
||||
for name, timeout in timeouts.items():
|
||||
if timeout is None:
|
||||
timeout = self._DEFAULT_TIMEOUTS[name]
|
||||
@ -395,7 +395,7 @@ class _BackendRendezvousStateHolder(_RendezvousStateHolder):
|
||||
self._last_sync_time = -1
|
||||
self._dead_nodes = []
|
||||
|
||||
def _record(self, message: str, node_state: NodeState = NodeState.RUNNING):
|
||||
def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
|
||||
construct_and_record_rdzv_event(
|
||||
name=f"{self.__class__.__name__}.{get_method_name()}",
|
||||
run_id=self._settings.run_id,
|
||||
|
||||
@ -153,7 +153,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
|
||||
+--------------------------------------------+--------------------------+
|
||||
"""
|
||||
|
||||
def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]):
|
||||
def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]) -> None:
|
||||
"""
|
||||
Args:
|
||||
rdzv_impl: the implementation of the rendezvous
|
||||
@ -163,7 +163,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
|
||||
self._rdzv_impl = rdzv_impl
|
||||
self._local_addr = local_addr
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
# TODO: look into using weakref here instead.
|
||||
del self._rdzv_impl
|
||||
|
||||
@ -189,7 +189,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
|
||||
# No rendezvous state, so it cannot be closed.
|
||||
return False
|
||||
|
||||
def set_closed(self):
|
||||
def set_closed(self) -> None:
|
||||
self._rdzv_impl.set_closed()
|
||||
|
||||
def num_nodes_waiting(self):
|
||||
@ -231,7 +231,7 @@ class EtcdRendezvous:
|
||||
num_max_workers,
|
||||
timeout,
|
||||
last_call_timeout,
|
||||
):
|
||||
) -> None:
|
||||
self.client = client
|
||||
logger.info("Etcd machines: %s", self.client.machines)
|
||||
|
||||
@ -270,7 +270,7 @@ class EtcdRendezvous:
|
||||
except etcd.EtcdAlreadyExist:
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
# TODO: look into using weakref here instead.
|
||||
if self._lease_run_id_stop is not None:
|
||||
self._lease_run_id_stop.set()
|
||||
@ -442,7 +442,7 @@ class EtcdRendezvous:
|
||||
# Rendezvous version number; our rank in it; world size
|
||||
return state["version"], this_rank, len(state["participants"])
|
||||
|
||||
def handle_existing_rendezvous(self, expected_version):
|
||||
def handle_existing_rendezvous(self, expected_version) -> None:
|
||||
"""
|
||||
Handle the case when there's an existing (state 'final) rendezvous already
|
||||
in place, and we have to announce ourselves waiting, and wait until
|
||||
@ -678,7 +678,7 @@ class EtcdRendezvous:
|
||||
except etcd.EtcdCompareFailed:
|
||||
logger.info("Announce self as waiting CAS unsuccessful, retrying")
|
||||
|
||||
def wait_for_rendezvous_to_free(self, expected_version):
|
||||
def wait_for_rendezvous_to_free(self, expected_version) -> None:
|
||||
"""
|
||||
When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join.
|
||||
|
||||
@ -746,7 +746,7 @@ class EtcdRendezvous:
|
||||
raise RendezvousTimeoutError
|
||||
active_version, state = self.get_rdzv_state()
|
||||
|
||||
def handle_join_last_call(self, expected_version, deadline):
|
||||
def handle_join_last_call(self, expected_version, deadline) -> None:
|
||||
"""
|
||||
After we reach min number of workers, one particular worker takes on the
|
||||
responsibility of waiting an additional timeout before closing the join window.
|
||||
@ -820,7 +820,7 @@ class EtcdRendezvous:
|
||||
cas_delay()
|
||||
active_version, state = self.get_rdzv_state()
|
||||
|
||||
def set_closed(self):
|
||||
def set_closed(self) -> None:
|
||||
"""
|
||||
Mark rendezvous 'closed' for current run_id, which is used to signal other
|
||||
participants to not attempt to perform (re-)rendezvous. This is useful
|
||||
@ -868,13 +868,13 @@ class EtcdRendezvous:
|
||||
# Unfortunately, we have to do another fetch in order to get last etcd_index.
|
||||
return self.get_rdzv_state()
|
||||
|
||||
def get_path(self, path):
|
||||
def get_path(self, path) -> str:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
return f"{self._prefix}run_{self._run_id}{path}"
|
||||
|
||||
def create_path_if_not_exists(self, full_path, ttl=None):
|
||||
def create_path_if_not_exists(self, full_path, ttl=None) -> None:
|
||||
try:
|
||||
self.client.write(
|
||||
key=full_path, value=None, dir=True, prevExist=False, ttl=ttl
|
||||
@ -888,7 +888,7 @@ class EtcdRendezvous:
|
||||
# release the Python's GIL! An example of this is calling a pybind11
|
||||
# extension function that is blocking / long-running, but is not
|
||||
# doing a scoped release of the GIL.
|
||||
def lease_worker(client, path, ttl, stop_event):
|
||||
def lease_worker(client, path, ttl, stop_event) -> None:
|
||||
while True:
|
||||
try:
|
||||
client.refresh(path, ttl=ttl)
|
||||
@ -912,7 +912,7 @@ class EtcdRendezvous:
|
||||
|
||||
return lease_stop_event
|
||||
|
||||
def store_extra_data(self, rdzv_version, key, value):
|
||||
def store_extra_data(self, rdzv_version, key, value) -> None:
|
||||
node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data")
|
||||
try:
|
||||
# If first time we are storing anything:
|
||||
|
||||
@ -64,7 +64,7 @@ def find_free_port():
|
||||
raise RuntimeError("Failed to create a socket")
|
||||
|
||||
|
||||
def stop_etcd(subprocess, data_dir: Optional[str] = None):
|
||||
def stop_etcd(subprocess, data_dir: Optional[str] = None) -> None:
|
||||
if subprocess and subprocess.poll() is None:
|
||||
logger.info("stopping etcd server")
|
||||
subprocess.terminate()
|
||||
@ -107,7 +107,7 @@ class EtcdServer:
|
||||
etcd_binary_path: path of etcd server binary (see above for fallback path)
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: Optional[str] = None):
|
||||
def __init__(self, data_dir: Optional[str] = None) -> None:
|
||||
self._port = -1
|
||||
self._host = "localhost"
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ except ModuleNotFoundError:
|
||||
|
||||
# Delay (sleep) for a small random amount to reduce CAS failures.
|
||||
# This does not affect correctness, but will reduce requests to etcd server.
|
||||
def cas_delay():
|
||||
def cas_delay() -> None:
|
||||
time.sleep(random.uniform(0, 0.1))
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ class EtcdStore(Store):
|
||||
etcd_store_prefix,
|
||||
# Default timeout same as in c10d/Store.hpp
|
||||
timeout: Optional[datetime.timedelta] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__() # required for pybind trampoline.
|
||||
|
||||
self.client = etcd_client
|
||||
@ -53,7 +53,7 @@ class EtcdStore(Store):
|
||||
if not self.prefix.endswith("/"):
|
||||
self.prefix += "/"
|
||||
|
||||
def set(self, key, value):
|
||||
def set(self, key, value) -> None:
|
||||
"""
|
||||
Write a key/value pair into ``EtcdStore``.
|
||||
|
||||
@ -121,7 +121,7 @@ class EtcdStore(Store):
|
||||
except etcd.EtcdCompareFailed:
|
||||
cas_delay()
|
||||
|
||||
def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
|
||||
def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None) -> None:
|
||||
"""
|
||||
Wait until all of the keys are published, or until timeout.
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ class StaticTCPRendezvous(RendezvousHandler):
|
||||
world_size: int,
|
||||
run_id: str,
|
||||
timeout: int,
|
||||
):
|
||||
) -> None:
|
||||
self.master_addr = master_addr
|
||||
self.master_port = master_port
|
||||
self.rank = rank
|
||||
@ -82,13 +82,13 @@ class StaticTCPRendezvous(RendezvousHandler):
|
||||
bootstrap_store_info,
|
||||
)
|
||||
|
||||
def is_closed(self):
|
||||
def is_closed(self) -> bool:
|
||||
return False
|
||||
|
||||
def set_closed(self):
|
||||
def set_closed(self) -> None:
|
||||
pass
|
||||
|
||||
def num_nodes_waiting(self):
|
||||
def num_nodes_waiting(self) -> int:
|
||||
return 0
|
||||
|
||||
def get_run_id(self) -> str:
|
||||
|
||||
@ -279,7 +279,7 @@ class _PeriodicTimer:
|
||||
ctx.function(*ctx.args, **ctx.kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _stop_thread(thread, stop_event):
|
||||
def _stop_thread(thread, stop_event) -> None:
|
||||
stop_event.set()
|
||||
|
||||
thread.join()
|
||||
|
||||
@ -39,7 +39,7 @@ class TimerRequest:
|
||||
|
||||
__slots__ = ["worker_id", "scope_id", "expiration_time"]
|
||||
|
||||
def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
|
||||
def __init__(self, worker_id: Any, scope_id: str, expiration_time: float) -> None:
|
||||
self.worker_id = worker_id
|
||||
self.scope_id = scope_id
|
||||
self.expiration_time = expiration_time
|
||||
@ -119,7 +119,7 @@ class TimerServer(abc.ABC):
|
||||
|
||||
def __init__(
|
||||
self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
:param request_queue: Consumer ``RequestQueue``
|
||||
:param max_interval: max time (in seconds) to wait
|
||||
@ -179,14 +179,14 @@ class TimerServer(abc.ABC):
|
||||
)
|
||||
return True
|
||||
|
||||
def _watchdog_loop(self):
|
||||
def _watchdog_loop(self) -> None:
|
||||
while not self._stop_signaled:
|
||||
try:
|
||||
self._run_watchdog()
|
||||
except Exception:
|
||||
logger.exception("Error running watchdog")
|
||||
|
||||
def _run_watchdog(self):
|
||||
def _run_watchdog(self) -> None:
|
||||
batch_size = max(1, self._request_queue.size())
|
||||
timer_requests = self._request_queue.get(batch_size, self._max_interval)
|
||||
self.register_timers(timer_requests)
|
||||
@ -237,7 +237,7 @@ class TimerServer(abc.ABC):
|
||||
_timer_client: Optional[TimerClient] = None
|
||||
|
||||
|
||||
def configure(timer_client: TimerClient):
|
||||
def configure(timer_client: TimerClient) -> None:
|
||||
"""
|
||||
Configures a timer client. Must be called before using ``expires``.
|
||||
"""
|
||||
|
||||
@ -19,6 +19,6 @@ __all__ = ["log_debug_info_for_expired_timers"]
|
||||
def log_debug_info_for_expired_timers(
|
||||
run_id: str,
|
||||
expired_timers: dict[int, list[str]],
|
||||
):
|
||||
) -> None:
|
||||
if expired_timers:
|
||||
logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)
|
||||
|
||||
@ -263,7 +263,7 @@ class FileTimerServer:
|
||||
os.remove(self._file_path)
|
||||
|
||||
@staticmethod
|
||||
def is_process_running(pid: int):
|
||||
def is_process_running(pid: int) -> bool | None:
|
||||
"""
|
||||
function to check process is running or not
|
||||
"""
|
||||
|
||||
@ -29,16 +29,16 @@ class LocalTimerClient(TimerClient):
|
||||
GPU devices.
|
||||
"""
|
||||
|
||||
def __init__(self, mp_queue):
|
||||
def __init__(self, mp_queue) -> None:
|
||||
super().__init__()
|
||||
self._mp_queue = mp_queue
|
||||
|
||||
def acquire(self, scope_id, expiration_time):
|
||||
def acquire(self, scope_id, expiration_time) -> None:
|
||||
pid = os.getpid()
|
||||
acquire_request = TimerRequest(pid, scope_id, expiration_time)
|
||||
self._mp_queue.put(acquire_request)
|
||||
|
||||
def release(self, scope_id):
|
||||
def release(self, scope_id) -> None:
|
||||
pid = os.getpid()
|
||||
release_request = TimerRequest(pid, scope_id, -1)
|
||||
self._mp_queue.put(release_request)
|
||||
@ -49,7 +49,7 @@ class MultiprocessingRequestQueue(RequestQueue):
|
||||
A ``RequestQueue`` backed by python ``multiprocessing.Queue``
|
||||
"""
|
||||
|
||||
def __init__(self, mp_queue: mp.Queue):
|
||||
def __init__(self, mp_queue: mp.Queue) -> None:
|
||||
super().__init__()
|
||||
self._mp_queue = mp_queue
|
||||
|
||||
@ -86,7 +86,7 @@ class LocalTimerServer(TimerServer):
|
||||
|
||||
def __init__(
|
||||
self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
|
||||
self._timers: dict[tuple[Any, str], TimerRequest] = {}
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ class CyclingIterator(Iterator[_T]):
|
||||
n: int,
|
||||
generator_fn: Callable[[int], Iterator[_T]],
|
||||
start_epoch: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
self._n = n
|
||||
self._epoch = start_epoch
|
||||
self._generator_fn = generator_fn
|
||||
|
||||
@ -47,7 +47,7 @@ class ElasticDistributedSampler(DistributedSampler[T]):
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
start_index: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
|
||||
if not isinstance(dataset, Sized):
|
||||
raise TypeError("Dataset must be an instance of collections.abc.Sized")
|
||||
|
||||
@ -115,7 +115,7 @@ def create_c10d_store(
|
||||
raise
|
||||
|
||||
|
||||
def _check_full_rank(store, world_size, timeout):
|
||||
def _check_full_rank(store, world_size, timeout) -> None:
|
||||
try:
|
||||
barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
|
||||
except RuntimeError as e:
|
||||
|
||||
@ -3,7 +3,9 @@ import torch
|
||||
from torch.distributed._tools import MemoryTracker
|
||||
|
||||
|
||||
def run_one_model(net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"):
|
||||
def run_one_model(
|
||||
net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"
|
||||
) -> None:
|
||||
net.to(device)
|
||||
input = input.to(device)
|
||||
|
||||
|
||||
@ -61,7 +61,7 @@ class _FSDPDeviceHandle:
|
||||
semantics to be integrated with FSDP.
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device, backend: Any = None):
|
||||
def __init__(self, device: torch.device, backend: Any = None) -> None:
|
||||
if backend is None:
|
||||
try:
|
||||
self.__backend = getattr(torch, device.type)
|
||||
@ -187,7 +187,7 @@ class HandleTrainingState(Enum):
|
||||
SUMMON_FULL_PARAMS = auto()
|
||||
|
||||
|
||||
def _is_composable(state: _FSDPState):
|
||||
def _is_composable(state: _FSDPState) -> bool:
|
||||
# TODO: This is a temporary hack for differentiate between code paths.
|
||||
return not isinstance(state, nn.Module)
|
||||
|
||||
@ -305,7 +305,7 @@ def _get_param_to_fqns(
|
||||
includes the FQNs across all encounters. (Default: ``True``)
|
||||
"""
|
||||
|
||||
def module_fn(module, prefix, tree_level, param_to_fqns):
|
||||
def module_fn(module, prefix, tree_level, param_to_fqns) -> None:
|
||||
for param_name, param in _named_parameters_with_duplicates(
|
||||
module, recurse=False
|
||||
):
|
||||
@ -400,7 +400,9 @@ def _apply_to_modules(
|
||||
to remove the prefix.
|
||||
"""
|
||||
|
||||
def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
|
||||
def f(
|
||||
module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs
|
||||
) -> None:
|
||||
# Call the module function before recursing over children (pre-order)
|
||||
module_fn(module, prefix, tree_level, *args, **kwargs)
|
||||
for submodule_name, submodule in module.named_children():
|
||||
|
||||
@ -106,7 +106,7 @@ def _get_sharded_module_tree_with_module_name_to_fqns(
|
||||
|
||||
def module_fn(
|
||||
module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
|
||||
):
|
||||
) -> None:
|
||||
num_spaces = tree_level * 4
|
||||
trimed_prefix = (
|
||||
prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
|
||||
|
||||
@ -352,7 +352,7 @@ class _ExecOrderData:
|
||||
fqns.append(self.param_to_fqn[flat_param])
|
||||
return fqns
|
||||
|
||||
def next_iter(self):
|
||||
def next_iter(self) -> None:
|
||||
"""
|
||||
Advances the internal data structures per iteration. This should be
|
||||
called in the post-backward callback since that marks the true end of
|
||||
|
||||
@ -192,7 +192,7 @@ class FlatParamShardMetadata(NamedTuple):
|
||||
class _FlatParameterMeta(_ParameterMeta):
|
||||
# Make `isinstance(t, FlatParameter)` return True for custom tensor
|
||||
# instances that have the _is_flat_param flag for BC
|
||||
def __instancecheck__(self, instance):
|
||||
def __instancecheck__(self, instance) -> bool:
|
||||
# NB: do NOT test the super implementation
|
||||
return isinstance(instance, torch.Tensor) and getattr(
|
||||
instance, "_is_flat_param", False
|
||||
@ -525,7 +525,7 @@ class FlatParamHandle:
|
||||
use_orig_params: bool,
|
||||
*,
|
||||
fsdp_extension: Optional[FSDPExtensions] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
params = list(params)
|
||||
if len(params) == 0:
|
||||
@ -621,10 +621,10 @@ class FlatParamHandle:
|
||||
)
|
||||
self._use_unsharded_views(as_params=False)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"FlatParamHandle(flat_param.fqns={self.flat_param._fqns})"
|
||||
|
||||
def _init_setattr_fns(self):
|
||||
def _init_setattr_fns(self) -> None:
|
||||
use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
|
||||
self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
|
||||
self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
|
||||
@ -635,7 +635,7 @@ class FlatParamHandle:
|
||||
self._setattr_tensor = _safe_setattr_tensor_or_param
|
||||
self._setattr_param = _safe_setattr_tensor_or_param
|
||||
|
||||
def _init_get_unflat_views_fn(self, align_addresses: bool):
|
||||
def _init_get_unflat_views_fn(self, align_addresses: bool) -> None:
|
||||
self._get_unflat_views = (
|
||||
self._get_unflat_views_aligned
|
||||
if align_addresses
|
||||
@ -945,7 +945,7 @@ class FlatParamHandle:
|
||||
# SHARD INITIALIZATION & METADATA #
|
||||
###################################
|
||||
@torch.no_grad()
|
||||
def shard(self):
|
||||
def shard(self) -> None:
|
||||
"""
|
||||
Shard the handle's ``FlatParameter``.
|
||||
|
||||
@ -1342,7 +1342,7 @@ class FlatParamHandle:
|
||||
self._check_on_compute_device(self.flat_param)
|
||||
return ret
|
||||
|
||||
def _use_low_precision_shard(self):
|
||||
def _use_low_precision_shard(self) -> None:
|
||||
"""Allocate on the compute device and switch to using the low precision sharded flat parameter."""
|
||||
self._check_low_precision_shard()
|
||||
flat_param = self.flat_param
|
||||
@ -1359,7 +1359,7 @@ class FlatParamHandle:
|
||||
# Invariant: `_mp_shard` is always on the compute device.
|
||||
flat_param.data = flat_param._mp_shard # type: ignore[attr-defined]
|
||||
|
||||
def unshard(self):
|
||||
def unshard(self) -> None:
|
||||
"""
|
||||
Run the unshard logic.
|
||||
|
||||
@ -1523,7 +1523,7 @@ class FlatParamHandle:
|
||||
elif in_forward:
|
||||
self._use_unsharded_views(as_params=False)
|
||||
|
||||
def post_unshard(self):
|
||||
def post_unshard(self) -> None:
|
||||
"""
|
||||
Run the post-unshard logic.
|
||||
|
||||
@ -1533,7 +1533,7 @@ class FlatParamHandle:
|
||||
self._free_low_precision_sharded_param()
|
||||
self._check_on_compute_device(self.flat_param)
|
||||
|
||||
def _free_low_precision_sharded_param(self):
|
||||
def _free_low_precision_sharded_param(self) -> None:
|
||||
"""Frees the low precision sharded flat parameter."""
|
||||
self._check_low_precision_shard()
|
||||
# `_mp_shard` is allocated in the pre-unshard stream, consumed in the
|
||||
@ -1550,7 +1550,7 @@ class FlatParamHandle:
|
||||
_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
|
||||
|
||||
@torch.no_grad()
|
||||
def unshard_grad(self):
|
||||
def unshard_grad(self) -> None:
|
||||
"""
|
||||
Unshard the handle's ``FlatParameter``'s gradient.
|
||||
|
||||
@ -1608,7 +1608,7 @@ class FlatParamHandle:
|
||||
)
|
||||
self._use_unsharded_grad_views()
|
||||
|
||||
def reshard_grad(self):
|
||||
def reshard_grad(self) -> None:
|
||||
if self._use_orig_params:
|
||||
self._use_sharded_grad_views()
|
||||
if not self.uses_sharded_strategy:
|
||||
@ -1616,7 +1616,7 @@ class FlatParamHandle:
|
||||
self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined]
|
||||
delattr(self.flat_param, "_saved_grad_shard")
|
||||
|
||||
def prepare_gradient_for_backward(self):
|
||||
def prepare_gradient_for_backward(self) -> None:
|
||||
"""
|
||||
Prepare the gradient for the backward computation.
|
||||
|
||||
@ -1681,10 +1681,10 @@ class FlatParamHandle:
|
||||
)
|
||||
flat_param.grad = None
|
||||
|
||||
def prepare_gradient_for_optim(self):
|
||||
def prepare_gradient_for_optim(self) -> None:
|
||||
"""Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute."""
|
||||
|
||||
def cast_grad_to_param_dtype_if_needed(flat_param):
|
||||
def cast_grad_to_param_dtype_if_needed(flat_param) -> None:
|
||||
# TODO (rohan-varma): test for full precision with keep_low_precision_grads
|
||||
if not self._force_full_precision and self._keep_low_precision_grads:
|
||||
_p_assert(flat_param.grad is not None, "Unexpected None grad!")
|
||||
@ -1768,7 +1768,7 @@ class FlatParamHandle:
|
||||
)
|
||||
self._use_unsharded_flat_param(padded_unsharded_flat_param)
|
||||
|
||||
def reshard(self, free_unsharded_flat_param: bool):
|
||||
def reshard(self, free_unsharded_flat_param: bool) -> None:
|
||||
"""
|
||||
Run the reshard logic.
|
||||
|
||||
@ -1786,7 +1786,7 @@ class FlatParamHandle:
|
||||
if free_unsharded_flat_param:
|
||||
self._free_unsharded_flat_param()
|
||||
|
||||
def post_reshard(self):
|
||||
def post_reshard(self) -> None:
|
||||
"""
|
||||
Run the post-reshard logic.
|
||||
|
||||
@ -1807,7 +1807,7 @@ class FlatParamHandle:
|
||||
):
|
||||
self._free_low_precision_sharded_param()
|
||||
|
||||
def _free_unsharded_flat_param(self):
|
||||
def _free_unsharded_flat_param(self) -> None:
|
||||
"""
|
||||
Free the padded unsharded flat parameter. We allow this
|
||||
function to be called even when storage is not allocated
|
||||
@ -2451,7 +2451,7 @@ class FlatParamHandle:
|
||||
raise AssertionError("Expected _is_grad_none_mask to be not None")
|
||||
self.flat_param._is_grad_none_mask[tensor_index] = True
|
||||
|
||||
def _reset_flat_param_grad_info_if_needed(self):
|
||||
def _reset_flat_param_grad_info_if_needed(self) -> None:
|
||||
"""
|
||||
Reset ``flat_param.grad`` if needed.
|
||||
|
||||
@ -2480,7 +2480,7 @@ class FlatParamHandle:
|
||||
# must require gradient
|
||||
flat_param.requires_grad = requires_grad
|
||||
|
||||
def _deregister_orig_params(self):
|
||||
def _deregister_orig_params(self) -> None:
|
||||
for param_info in self.flat_param._param_infos:
|
||||
param_name, module, _ = param_info
|
||||
if hasattr(module, param_name):
|
||||
@ -2492,7 +2492,7 @@ class FlatParamHandle:
|
||||
###########
|
||||
# HELPERS #
|
||||
###########
|
||||
def flat_param_to(self, *args, **kwargs):
|
||||
def flat_param_to(self, *args, **kwargs) -> None:
|
||||
"""Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
|
||||
# pyrefly: ignore [not-iterable]
|
||||
self.flat_param.data = self.flat_param.to(*args, **kwargs)
|
||||
@ -2628,23 +2628,23 @@ class FlatParamHandle:
|
||||
#######################
|
||||
# CHECKS & INVARIANTS #
|
||||
#######################
|
||||
def _check_sharded_strategy(self):
|
||||
def _check_sharded_strategy(self) -> None:
|
||||
_p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
|
||||
|
||||
def _check_on_compute_device(self, tensor: Tensor):
|
||||
def _check_on_compute_device(self, tensor: Tensor) -> None:
|
||||
_p_assert(
|
||||
tensor.device == self.device,
|
||||
f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
|
||||
)
|
||||
|
||||
def _check_on_cpu(self, tensor: Tensor):
|
||||
def _check_on_cpu(self, tensor: Tensor) -> None:
|
||||
_p_assert(
|
||||
tensor.device == torch.device("cpu"),
|
||||
f"Expects tensor to be on CPU but got {tensor.device}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _check_storage_freed(tensor: Tensor):
|
||||
def _check_storage_freed(tensor: Tensor) -> None:
|
||||
# Compile does not resize during trace
|
||||
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
|
||||
_p_assert(
|
||||
@ -2653,10 +2653,10 @@ class FlatParamHandle:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _check_storage_allocated(tensor: Tensor):
|
||||
def _check_storage_allocated(tensor: Tensor) -> None:
|
||||
_p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated")
|
||||
|
||||
def _check_low_precision_shard(self):
|
||||
def _check_low_precision_shard(self) -> None:
|
||||
_p_assert(
|
||||
self._uses_param_mixed_precision,
|
||||
"Not using low precision for parameters",
|
||||
@ -2671,7 +2671,7 @@ class FlatParamHandle:
|
||||
f"Expects the low precision shard to be on {self.device} but got {device}",
|
||||
)
|
||||
|
||||
def _check_unsharded(self, tensor: Tensor):
|
||||
def _check_unsharded(self, tensor: Tensor) -> None:
|
||||
msg_prefix = "Expects tensor to be unsharded "
|
||||
_p_assert(tensor is not None, msg_prefix + "but got `None`")
|
||||
unsharded_size = self.flat_param._unpadded_unsharded_size
|
||||
@ -2680,7 +2680,7 @@ class FlatParamHandle:
|
||||
msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
|
||||
)
|
||||
|
||||
def _check_sharded(self, tensor: Tensor):
|
||||
def _check_sharded(self, tensor: Tensor) -> None:
|
||||
msg_prefix = "Expects tensor to be sharded "
|
||||
_p_assert(tensor is not None, msg_prefix + "but got `None`")
|
||||
sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
|
||||
@ -2746,7 +2746,7 @@ def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -
|
||||
|
||||
def _safe_setattr_tensor_or_param(
|
||||
module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
|
||||
):
|
||||
) -> None:
|
||||
# Call `delattr()` and `setattr()` to go through `nn.Module` checks
|
||||
if hasattr(module, param_name):
|
||||
delattr(module, param_name)
|
||||
@ -2804,19 +2804,19 @@ def _construct_padding_tensor(
|
||||
# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning
|
||||
# message is passed in)
|
||||
@functools.lru_cache(1)
|
||||
def _warn_skip_writeback_check(log: logging.Logger, warning: str):
|
||||
def _warn_skip_writeback_check(log: logging.Logger, warning: str) -> None:
|
||||
logger.warning(warning)
|
||||
|
||||
|
||||
# Use `lru_cache(1)` to only log the warning once
|
||||
@functools.lru_cache(1)
|
||||
def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
|
||||
def _warn_use_fake_all_gather(log: logging.Logger, warning: str) -> None:
|
||||
logger.warning(warning)
|
||||
|
||||
|
||||
# Use `lru_cache(1)` to only log the warning once
|
||||
@functools.lru_cache(1)
|
||||
def _warn_use_fake_reduce(log: logging.Logger, warning: str):
|
||||
def _warn_use_fake_reduce(log: logging.Logger, warning: str) -> None:
|
||||
logger.warning(warning)
|
||||
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ class DefaultAllocMixin:
|
||||
|
||||
|
||||
class ProcessGroupAllocMixin:
|
||||
def __init__(self, group: dist.ProcessGroup, *args: Any, **kwargs: Any):
|
||||
def __init__(self, group: dist.ProcessGroup, *args: Any, **kwargs: Any) -> None:
|
||||
self._group = group
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()")
|
||||
@torch.library.impl(lib, "copy_", "HPU")
|
||||
@torch.library.impl(lib, "copy_", "CPU")
|
||||
@torch.library.impl(lib, "copy_", "MTIA")
|
||||
def copy_(tensor, data):
|
||||
def copy_(tensor, data) -> None:
|
||||
tensor.copy_(data)
|
||||
|
||||
|
||||
@ -130,7 +130,7 @@ But maybe there are few enough mutations induced by FSDP for this to matter.
|
||||
|
||||
|
||||
@torch.library.impl(lib, "copy_", "Functionalize")
|
||||
def copy__functionalize(tensor, data):
|
||||
def copy__functionalize(tensor, data) -> None:
|
||||
torch._sync(tensor)
|
||||
torch._sync(data)
|
||||
tensor_inner = torch._from_functional_tensor(tensor)
|
||||
@ -185,7 +185,7 @@ class ExtensionsData:
|
||||
# Save the all-gather input sizes to unflatten the all-gather outputs to ND
|
||||
all_gather_input_sizes: Sequence[torch.Size] = () # ND
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
self.all_gather_metadata = None
|
||||
self.all_gather_input_sizes = ()
|
||||
|
||||
@ -229,7 +229,7 @@ class FSDPParam:
|
||||
shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
|
||||
mp_policy: MixedPrecisionPolicy,
|
||||
offload_policy: OffloadPolicy,
|
||||
):
|
||||
) -> None:
|
||||
self._module_info: ParamModuleInfo = module_info
|
||||
self.mesh_info = mesh_info
|
||||
self.post_forward_mesh_info = post_forward_mesh_info
|
||||
@ -262,7 +262,7 @@ class FSDPParam:
|
||||
param: nn.Parameter,
|
||||
device: torch.device,
|
||||
shard_placement_fn: Optional[Callable],
|
||||
):
|
||||
) -> None:
|
||||
if param.device != device and param.device.type != "meta":
|
||||
raise AssertionError(
|
||||
f"Expects the parameter to already be moved to device {device} but got {param.device}"
|
||||
@ -434,7 +434,7 @@ class FSDPParam:
|
||||
self.sharded_post_forward_size
|
||||
)
|
||||
|
||||
def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
|
||||
def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy) -> None:
|
||||
param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
|
||||
self.orig_dtype = self.sharded_param.dtype
|
||||
# Clamp `reduce_dtype` to `None` if no casting is required: since
|
||||
@ -469,7 +469,7 @@ class FSDPParam:
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
force_recreate: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if not force_recreate and len(self.all_gather_outputs) > 0:
|
||||
return # already initialized
|
||||
self.all_gather_outputs = [
|
||||
@ -477,7 +477,7 @@ class FSDPParam:
|
||||
for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
|
||||
]
|
||||
|
||||
def init_unsharded_param(self):
|
||||
def init_unsharded_param(self) -> None:
|
||||
"""
|
||||
[Note: Invariants for torch.compile Traceable FSDP2]
|
||||
1. Under compile, we always re-populate the content of `self._unsharded_param`
|
||||
@ -863,7 +863,7 @@ class FSDPParam:
|
||||
f"Expects to be in one of {states}, not {self.sharded_state}"
|
||||
)
|
||||
|
||||
def reset_sharded_param(self):
|
||||
def reset_sharded_param(self) -> None:
|
||||
# For ops like `nn.Module._apply` or `load_state_dict(assign=True)`
|
||||
# that change the sharded parameter tensor, we may need to re-pad the
|
||||
# sharded local tensor and re-save the reference.
|
||||
@ -930,7 +930,7 @@ class FSDPParam:
|
||||
)
|
||||
self._sharding_spec = self.sharded_param._spec
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})"
|
||||
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ reference to avoid holding onto memory after forward.
|
||||
class FSDPCommContext:
|
||||
"""This has the communication state shared across FSDP states/parameter groups."""
|
||||
|
||||
def lazy_init(self, device: torch.device):
|
||||
def lazy_init(self, device: torch.device) -> None:
|
||||
self.device_handle = _get_device_handle(device.type)
|
||||
# Setting the all-gather/reduce-scatter streams to be higher priority
|
||||
# can help avoid some issues where their copies in/out are delayed and
|
||||
@ -133,7 +133,7 @@ class FSDPParamGroup:
|
||||
shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
|
||||
mp_policy: MixedPrecisionPolicy,
|
||||
offload_policy: OffloadPolicy,
|
||||
):
|
||||
) -> None:
|
||||
self.modules = modules # permit ref cycle because 1:1 lifetime
|
||||
param_module_infos = _get_param_module_infos(params, modules)
|
||||
|
||||
@ -248,7 +248,7 @@ class FSDPParamGroup:
|
||||
)
|
||||
self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
|
||||
|
||||
def lazy_init(self):
|
||||
def lazy_init(self) -> None:
|
||||
# Lazy init should be idempotent
|
||||
# Users may change or register parameters after construction time.
|
||||
# For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on
|
||||
@ -300,7 +300,7 @@ class FSDPParamGroup:
|
||||
)
|
||||
|
||||
# Runtime #
|
||||
def unshard(self, async_op: bool = False):
|
||||
def unshard(self, async_op: bool = False) -> None:
|
||||
if self._all_gather_result is not None: # already called, pending wait
|
||||
return
|
||||
if self.is_unsharded:
|
||||
@ -344,7 +344,7 @@ class FSDPParamGroup:
|
||||
self._all_gather_comm,
|
||||
)
|
||||
|
||||
def wait_for_unshard(self):
|
||||
def wait_for_unshard(self) -> None:
|
||||
"""
|
||||
1. In forward with implicit prefetching, to overlap the current copy-out
|
||||
with the next all-gather, we save a reference to the current all-gather
|
||||
@ -419,14 +419,14 @@ class FSDPParamGroup:
|
||||
|
||||
self._all_gather_result = None # free unless saved in `all_gather_state`
|
||||
|
||||
def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]):
|
||||
def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]) -> None:
|
||||
# Calling `unshard` before lazy init means streams are not initialized
|
||||
if hasattr(self.comm_ctx, "all_gather_copy_in_stream") and event is not None:
|
||||
self.comm_ctx.all_gather_copy_in_stream.wait_event(event)
|
||||
if hasattr(self.comm_ctx, "all_gather_stream") and event is not None:
|
||||
self.comm_ctx.all_gather_stream.wait_event(event)
|
||||
|
||||
def reshard(self):
|
||||
def reshard(self) -> None:
|
||||
if self._training_state == TrainingState.FORWARD:
|
||||
if not self._reshard_after_forward:
|
||||
return
|
||||
@ -473,7 +473,7 @@ class FSDPParamGroup:
|
||||
self.comm_ctx.post_forward_order.append(self)
|
||||
self._post_forward_indices.append(post_forward_index)
|
||||
|
||||
def pre_backward(self, default_prefetch: bool, *unused: Any):
|
||||
def pre_backward(self, default_prefetch: bool, *unused: Any) -> None:
|
||||
if (
|
||||
compiled_autograd_enabled()
|
||||
and self._training_state == TrainingState.PRE_BACKWARD
|
||||
@ -492,7 +492,7 @@ class FSDPParamGroup:
|
||||
if default_prefetch and not compiled_autograd_enabled():
|
||||
self._backward_prefetch()
|
||||
|
||||
def post_backward(self, *unused: Any):
|
||||
def post_backward(self, *unused: Any) -> None:
|
||||
# This method should be idempotent and safe to call even when this
|
||||
# FSDP parameter group was not used in backward (should be a no-op)
|
||||
if not compiled_autograd_enabled():
|
||||
@ -601,7 +601,7 @@ class FSDPParamGroup:
|
||||
all_reduce_input, all_reduce_event
|
||||
)
|
||||
|
||||
def finalize_backward(self):
|
||||
def finalize_backward(self) -> None:
|
||||
self._wait_for_post_backward()
|
||||
for fsdp_param in self.fsdp_params:
|
||||
if fsdp_param.grad_offload_event is not None:
|
||||
@ -618,7 +618,7 @@ class FSDPParamGroup:
|
||||
self._all_gather_result = None
|
||||
self._post_forward_indices.clear()
|
||||
|
||||
def _wait_for_post_backward(self):
|
||||
def _wait_for_post_backward(self) -> None:
|
||||
if self._post_reduce_event is not None:
|
||||
self.device_handle.current_stream().wait_event(self._post_reduce_event)
|
||||
self._post_reduce_event = None
|
||||
@ -663,19 +663,19 @@ class FSDPParamGroup:
|
||||
target_fsdp_param_group.unshard(async_op)
|
||||
|
||||
# Utilities #
|
||||
def _to_sharded(self):
|
||||
def _to_sharded(self) -> None:
|
||||
if not self.is_sharded:
|
||||
for fsdp_param in self.fsdp_params:
|
||||
fsdp_param.to_sharded()
|
||||
self._sharded_state = ShardedState.SHARDED
|
||||
|
||||
def _to_sharded_post_forward(self):
|
||||
def _to_sharded_post_forward(self) -> None:
|
||||
if not self.is_sharded_post_forward:
|
||||
for fsdp_param in self.fsdp_params:
|
||||
fsdp_param.to_sharded_post_forward()
|
||||
self._sharded_state = ShardedState.SHARDED_POST_FORWARD
|
||||
|
||||
def _to_unsharded(self):
|
||||
def _to_unsharded(self) -> None:
|
||||
if not self.is_unsharded:
|
||||
for fsdp_param in self.fsdp_params:
|
||||
fsdp_param.to_unsharded()
|
||||
@ -806,10 +806,10 @@ class FSDPParamGroup:
|
||||
return f"{label} ({self._module_fqn})"
|
||||
return label
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"FSDPParamGroup(fqn={self._module_fqn})"
|
||||
|
||||
def _validate_no_meta_params(self):
|
||||
def _validate_no_meta_params(self) -> None:
|
||||
param_names_on_meta = [
|
||||
fsdp_param._param_fqn
|
||||
for fsdp_param in self.fsdp_params
|
||||
@ -823,7 +823,7 @@ class FSDPParamGroup:
|
||||
"call module.reset_parameters() on each module to initialize values."
|
||||
)
|
||||
|
||||
def _validate_cpu_offload_params(self):
|
||||
def _validate_cpu_offload_params(self) -> None:
|
||||
if not isinstance(self.offload_policy, CPUOffloadPolicy):
|
||||
return
|
||||
fsdp_params_not_on_cpu = [
|
||||
@ -873,7 +873,7 @@ def _get_param_module_infos(
|
||||
|
||||
class RegisterPostBackwardFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def _assert_not_tracing_fsdp():
|
||||
def _assert_not_tracing_fsdp() -> None:
|
||||
if compiled_autograd_enabled():
|
||||
# TODO: Find a way to print the offending FSDP2 module.
|
||||
msg = """\
|
||||
|
||||
@ -347,7 +347,7 @@ class FSDPState(_State):
|
||||
t.register_hook(self._pre_backward)
|
||||
return output
|
||||
|
||||
def _register_root_post_backward_final_callback(self):
|
||||
def _register_root_post_backward_final_callback(self) -> None:
|
||||
if self._state_ctx.post_backward_final_callback_queued:
|
||||
return
|
||||
self._state_ctx.post_backward_final_callback_queued = True
|
||||
|
||||
@ -632,7 +632,7 @@ def _init_param_handle_from_params(
|
||||
state: _FSDPState,
|
||||
params: list[nn.Parameter],
|
||||
fully_sharded_module: nn.Module,
|
||||
):
|
||||
) -> None:
|
||||
if len(params) == 0:
|
||||
return
|
||||
handle = FlatParamHandle(
|
||||
@ -905,7 +905,7 @@ def _materialize_meta_module(
|
||||
device_from_device_id: Optional[torch.device],
|
||||
ignored_modules: set[nn.Module],
|
||||
device_handle: _FSDPDeviceHandle,
|
||||
):
|
||||
) -> None:
|
||||
# Run default meta device initialization
|
||||
materialization_device = device_from_device_id or torch.device(
|
||||
device_handle.current_device()
|
||||
@ -1046,7 +1046,7 @@ def _move_states_to_device(
|
||||
_warn_cpu_init()
|
||||
|
||||
|
||||
def _warn_cpu_init():
|
||||
def _warn_cpu_init() -> None:
|
||||
warnings.warn(
|
||||
"The passed-in `module` is on CPU and will thus have FSDP's sharding "
|
||||
"initialization run on CPU, which may be slower than on GPU. We "
|
||||
|
||||
@ -1051,7 +1051,7 @@ def _get_flat_param_to_fqn(model: torch.nn.Module) -> dict[FlatParameter, str]:
|
||||
|
||||
"""
|
||||
|
||||
def module_fn(module, prefix, tree_level, flat_param_to_fqn):
|
||||
def module_fn(module, prefix, tree_level, flat_param_to_fqn) -> None:
|
||||
for param_name, param in _named_parameters_with_duplicates(
|
||||
module, recurse=False
|
||||
):
|
||||
@ -2077,7 +2077,7 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]:
|
||||
parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters.
|
||||
"""
|
||||
|
||||
def module_fn(module, prefix, tree_level, fqn_to_param_info):
|
||||
def module_fn(module, prefix, tree_level, fqn_to_param_info) -> None:
|
||||
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
|
||||
if fsdp_state is None:
|
||||
return
|
||||
|
||||
@ -143,7 +143,7 @@ def _lazy_init(
|
||||
return state
|
||||
|
||||
|
||||
def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
|
||||
def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module) -> None:
|
||||
"""
|
||||
Checks that all ``FlatParameter``s in ``module`` 's tree managed by
|
||||
``state`` are on the expected device for *lazy initialization*.
|
||||
@ -311,7 +311,7 @@ def _reshard(
|
||||
state: _FSDPState,
|
||||
handle: FlatParamHandle,
|
||||
free_unsharded_flat_param: bool,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Reshards the handle. ``free_unsharded_flat_param`` indicates whether to
|
||||
free the handle's padded unsharded flat parameter.
|
||||
@ -703,7 +703,7 @@ def _post_backward_hook(
|
||||
handle: FlatParamHandle,
|
||||
flat_param,
|
||||
*unused: Any,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
|
||||
|
||||
@ -955,7 +955,7 @@ def _post_reduce_grad_callback(
|
||||
handle: FlatParamHandle,
|
||||
# Additional arguments needed for the callback logic
|
||||
grad_to_offload: torch.Tensor,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
This callback captures any logic to run after the gradient reduction
|
||||
finishes. Currently, this offloads the gradient to CPU if CPU offloading is
|
||||
@ -970,7 +970,7 @@ def _offload_grad(
|
||||
state: _FSDPState,
|
||||
handle: FlatParamHandle,
|
||||
grad_to_offload: torch.Tensor,
|
||||
):
|
||||
) -> None:
|
||||
if not handle._offload_params:
|
||||
return
|
||||
# Offload the gradient to CPU to ensure parameters and gradients are on the
|
||||
@ -992,7 +992,7 @@ def _offload_grad(
|
||||
|
||||
|
||||
@no_type_check
|
||||
def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
|
||||
def _post_backward_use_sharded_grad_views(handle: FlatParamHandle) -> None:
|
||||
if not handle._use_orig_params:
|
||||
return
|
||||
# Since the handle's `FlatParameter` completed its gradient computation, we
|
||||
@ -1031,7 +1031,7 @@ def _cast_grad_to_param_dtype(
|
||||
state: _FSDPState,
|
||||
sharded_grad: torch.Tensor,
|
||||
param: FlatParameter,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Casts ``sharded_grad`` back to the full parameter dtype so that the
|
||||
optimizer step runs with that dtype. This performs an actual cast if
|
||||
@ -1084,7 +1084,7 @@ def _low_precision_hook_enabled(state: _FSDPState) -> bool:
|
||||
def _post_backward_final_callback(
|
||||
state: _FSDPState,
|
||||
module: nn.Module,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
This waits for the post-backward to finish and performs some final cleanup.
|
||||
This runs at the end of the entire backward pass and should only be called
|
||||
@ -1346,7 +1346,7 @@ def _register_post_forward_hook(
|
||||
def _register_root_pre_forward_hook(
|
||||
state: _FSDPState,
|
||||
module: nn.Module,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Registers root pre-forward hook on ``module``, which should be the local
|
||||
FSDP root.
|
||||
@ -1544,7 +1544,7 @@ def _wait_for_computation_stream(
|
||||
computation_stream: torch.Stream,
|
||||
unshard_stream: torch.Stream,
|
||||
pre_unshard_stream: torch.Stream,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Has the unshard and pre-unshard streams wait for the computation stream.
|
||||
For example, this should be called in the FSDP root's pre-forward to
|
||||
@ -1562,7 +1562,7 @@ def _wait_for_computation_stream(
|
||||
|
||||
def _reset_flat_param_grad_info_if_needed(
|
||||
handles: list[FlatParamHandle],
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Clears the original parameters' gradients if needed. This method's CPU
|
||||
overhead is minimal, so we may call it throughout FSDP methods, which serve
|
||||
|
||||
@ -18,7 +18,7 @@ from torch.distributed._shard.sharding_spec import ShardMetadata
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
|
||||
|
||||
|
||||
def _get_remote_device_str(rank, device_type, num_devices_per_node):
|
||||
def _get_remote_device_str(rank, device_type, num_devices_per_node) -> str:
|
||||
if device_type.lower() == "cpu":
|
||||
return f"rank:{rank}/{device_type}"
|
||||
elif device_type.lower() == "hpu":
|
||||
|
||||
@ -542,7 +542,7 @@ def _sharded_post_state_dict_hook(
|
||||
with a unflattened, sharded parameter (a ShardedTensor).
|
||||
"""
|
||||
|
||||
def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str):
|
||||
def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str) -> None:
|
||||
param = state_dict[fqn]
|
||||
if not fsdp_state._state_dict_config._use_dtensor:
|
||||
sharded_tensor = _ext_chunk_tensor(
|
||||
@ -895,7 +895,7 @@ def _post_load_state_dict_hook(
|
||||
SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")
|
||||
|
||||
|
||||
def _register_all_state_dict_hooks(state: _FSDPState):
|
||||
def _register_all_state_dict_hooks(state: _FSDPState) -> None:
|
||||
"""
|
||||
Registers pre-save, post-save, pre-load, and post-load state dict hooks.
|
||||
"""
|
||||
|
||||
@ -35,7 +35,7 @@ FLAT_PARAM = "_flat_param"
|
||||
def _writeback_to_local_shard(
|
||||
handle: FlatParamHandle,
|
||||
writeback_grad: bool,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
For the handle, writes back the this rank's shard of the unsharded
|
||||
flattened parameter to the sharded flattened parameter. If
|
||||
|
||||
@ -30,7 +30,7 @@ def _auto_wrap(
|
||||
ignored_params: set[nn.Parameter],
|
||||
root_kwargs: dict[str, Any],
|
||||
fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Auto wraps modules in ``root_module`` 's tree according to ``policy``
|
||||
following a post-order traversal.
|
||||
@ -102,7 +102,7 @@ def _auto_wrap(
|
||||
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _check_nested_wrapping(root_module: nn.Module):
|
||||
def _check_nested_wrapping(root_module: nn.Module) -> None:
|
||||
for module_name, module in root_module.named_modules():
|
||||
if _get_module_fsdp_state(module) is not None:
|
||||
raise ValueError(
|
||||
@ -113,7 +113,7 @@ def _check_nested_wrapping(root_module: nn.Module):
|
||||
|
||||
def _warn_on_overridden_mixed_precision(
|
||||
overridden_module_classes: set[type[nn.Module]],
|
||||
):
|
||||
) -> None:
|
||||
if len(overridden_module_classes) == 0:
|
||||
return
|
||||
warnings.warn(
|
||||
@ -130,7 +130,7 @@ def _validate_frozen_params(
|
||||
modules_to_wrap: set[nn.Module],
|
||||
ignored_params: set[nn.Parameter],
|
||||
use_orig_params: bool,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
This checks that, given ``modules_to_wrap``, each module would manage
|
||||
parameters that are uniformly frozen or non-frozen. This uniformity
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user