Compare commits

...

5 Commits

Author SHA1 Message Date
8d4e1101c2 More fixes
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-11-10 08:36:19 +08:00
9074ffeb40 More fixes
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-11-10 08:36:19 +08:00
08537982e6 More fixes
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-11-10 08:36:19 +08:00
816779ad01 More fixes
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-11-10 08:36:19 +08:00
69f8d844ba Add return types of Python functions
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-11-10 08:36:16 +08:00
157 changed files with 543 additions and 501 deletions
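All five commits make the same kind of change: functions that return nothing gain an explicit -> None, and functions with a fixed or obvious result gain a concrete annotation such as -> bool, -> str, or -> Literal[False]. A minimal before/after sketch of that pattern, using a hypothetical helper rather than any file from this diff:

# Hypothetical illustration of the annotation pattern applied in these commits.
from typing import Literal


def reset_counters(counters):  # before: return type is inferred as Any
    counters.clear()


def reset_counters_typed(counters: dict[str, int]) -> None:  # after: explicit None
    counters.clear()


def supports_shallow_copy() -> Literal[False]:  # constant return gets a precise type
    return False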

View File

@ -73,7 +73,7 @@ if is_available():
_DistributedPdb().set_trace()
"""
def interaction(self, *args, **kwargs):
def interaction(self, *args, **kwargs) -> None:
_stdin = sys.stdin
try:
sys.stdin = open("/dev/stdin")
@ -83,7 +83,7 @@ if is_available():
_breakpoint_cache: dict[int, typing.Any] = {}
def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600):
def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600) -> None:
"""
Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
done with the breakpoint before continuing.
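Since the hunk above only shows the signature change, here is a hedged usage sketch of this distributed breakpoint helper; it assumes a process group has already been initialized (for example via torchrun), and the NaN check is only an illustrative trigger:

import torch
import torch.distributed as dist


def debug_if_nan(loss: torch.Tensor) -> None:
    # Drops rank 0 into pdb; every other rank blocks until rank 0 resumes.
    if torch.isnan(loss).any():
        dist.breakpoint(rank=0)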

View File

@ -16,7 +16,7 @@ _T = TypeVar("_T", covariant=True)
_P = ParamSpec("_P")
def generate_state_key(string="__composable_api_state_key"):
def generate_state_key(string: str = "__composable_api_state_key") -> str:
return f"{string}_{str(uuid.uuid4())}"
@ -183,7 +183,9 @@ def contract(
f"Outputs: {num_new_modules} modules"
)
def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
def check_fqn(
orig_fqns: list[str], new_fqns: list[str], check_key: str
) -> None:
if orig_fqns == new_fqns:
return

View File

@ -27,8 +27,7 @@ except Exception:
stacklevel=2,
)
def is_torchdynamo_compiling(): # type: ignore[misc]
return False
def is_torchdynamo_compiling() -> bool: # type: ignore[misc]
return False

View File

@ -316,7 +316,7 @@ class LocalIntNode:
return ConstantIntNode(next(iter(local_ints.values())))
return super().__new__(cls)
def __init__(self, local_ints: dict[int, int]):
def __init__(self, local_ints: dict[int, int]) -> None:
self._local_ints = local_ints
def maybe_as_int(self) -> Optional[int]:
@ -590,7 +590,7 @@ class LocalTensor(torch.Tensor):
@torch._disable_dynamo
@mark_subclass_constructor_exportable_experimental # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
def __deepcopy__(self, memo: dict[Any, Any] | None) -> "LocalTensor":
@ -765,7 +765,7 @@ class LocalTensorMode(TorchDispatchMode):
"""
# What ranks this local tensor mode is operating over
def __init__(self, ranks: Union[int, frozenset[int]]):
def __init__(self, ranks: Union[int, frozenset[int]]) -> None:
if isinstance(ranks, int):
# assume is world size
self.ranks = frozenset(range(ranks))
@ -1087,7 +1087,7 @@ class LocalRunnerMode:
def __init__(
self, ranks: frozenset[int] | int, concurrency: int, fn: Callable[[int], None]
):
) -> None:
if isinstance(ranks, int):
ranks = frozenset(range(ranks))
self._ranks = ranks

View File

@ -80,7 +80,7 @@ def shard_parameter(
sharding_spec: ShardingSpec,
src_rank=0,
process_group=None,
):
) -> None:
"""
Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
module, it shards that parameter according to the provided
@ -223,7 +223,9 @@ def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
return module
def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None):
def shard_module(
module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None
) -> None:
"""
Shards a given module according to the provided sharding `plan`. This method
first shards all the parameters according to the given sharding `plan`. Then if

View File

@ -5,7 +5,7 @@ import torch
from torch.utils import _pytree as pytree
def _basic_validation(op, args=(), kwargs=None):
def _basic_validation(op, args=(), kwargs=None) -> None:
"""
Common validation across all ops go in here.
"""
@ -17,7 +17,7 @@ def _basic_validation(op, args=(), kwargs=None):
# Validate types
has_distributed_tensor = False
def is_distributed_tensor(e):
def is_distributed_tensor(e) -> None:
nonlocal has_distributed_tensor
if isinstance(e, ShardedTensor):
has_distributed_tensor = True
@ -34,7 +34,7 @@ def _basic_validation(op, args=(), kwargs=None):
# Validate all distributed tensors use the same PG.
cur_pg: Optional[torch.distributed.ProcessGroup] = None
def validate_pg(e):
def validate_pg(e) -> None:
nonlocal cur_pg
if isinstance(e, ShardedTensor):
if cur_pg is not None and e._process_group is not cur_pg:
@ -48,7 +48,7 @@ def _basic_validation(op, args=(), kwargs=None):
pytree.tree_map_(validate_pg, kwargs)
def _register_default_op(op, decorator):
def _register_default_op(op, decorator) -> None:
@decorator(op)
def tensor_default_op(types, args=(), kwargs=None, pg=None):
"""

View File

@ -34,7 +34,7 @@ class ShardMetadata:
shard_offsets: list[int],
shard_sizes: list[int],
placement: Optional[Union[str, _remote_device]] = None,
):
) -> None:
self.shard_offsets = shard_offsets
self.shard_sizes = shard_sizes
if isinstance(placement, str):

View File

@ -11,7 +11,7 @@ and PartialTensor.
"""
def _register_op(op, func, op_table):
def _register_op(op, func, op_table) -> None:
"""
Performs basic validation and registers the provided op in the given
op_table.

View File

@ -409,7 +409,7 @@ def init_from_local_shards(
)
def state_dict_hook(module, destination, prefix, local_metadata):
def state_dict_hook(module, destination, prefix, local_metadata) -> None:
"""
Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
registered to the Module using
@ -432,7 +432,7 @@ def pre_load_state_dict_hook(
missing_keys,
unexpected_keys,
error_msgs,
):
) -> None:
"""
Pre-load state dict hook to add ShardedTensor to the module.
"""

View File

@ -62,7 +62,7 @@ def _sharded_op_common(op, early_stop_func, extra_check):
def _register_sharded_op_on_local_shards(
op, early_stop_func=None, extra_check=None, customized_func=None
):
) -> None:
"""
Handles ``__torch_function__`` dispatch for ops which are performed on
each shard of the sharded tensor such as elementwise op like

View File

@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.sharded_tensor import _sharded_op_impl
@ -135,7 +136,7 @@ tensor_like_creation_op_map = {
# tensor ops that behave the same as the default tensor
def register_tensor_creation_op(op):
def register_tensor_creation_op(op) -> None:
@_sharded_op_impl(op)
def tensor_creation_op(types, args=(), kwargs=None, pg=None):
"""

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from typing import Literal
import torch
from torch.distributed._shard.sharded_tensor import _sharded_op_impl
@ -8,5 +10,7 @@ from torch.distributed._shard.sharded_tensor import _sharded_op_impl
# the future behavior of overwriting the existing tensor
# instead of doing in-place change using `.data = `.
@_sharded_op_impl(torch._has_compatible_shallow_copy_type)
def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None):
def tensor_has_compatible_shallow_copy_type(
types, args=(), kwargs=None, pg=None
) -> Literal[False]:
return False

View File

@ -61,7 +61,7 @@ def st_is_meta(types, args=(), kwargs=None, pg=None):
return args[0].local_tensor().is_meta
def sharded_type_as_check(*args, **kwargs):
def sharded_type_as_check(*args, **kwargs) -> None:
"""
Perform extra checks for the sharded_type_as op such as the input needs to
be either a Tensor or ShardedTensor.

View File

@ -91,7 +91,7 @@ def _flatten_tensor_size(size) -> torch.Size:
return torch.Size(dims)
def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True) -> None:
if is_local:
assert isinstance(ranks, int)
if expected != actual:

View File

@ -5,7 +5,9 @@ from typing import Optional
from torch.distributed._shard.metadata import ShardMetadata
def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata):
def _check_shard_metadata_pair_overlap(
shard1: ShardMetadata, shard2: ShardMetadata
) -> bool:
"""
Checks if two shards overlap.
"""
@ -70,7 +72,7 @@ def _find_1d_overlapping_shards(
return None
def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]):
def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]) -> None:
"""
Ensures none of the shards overlap with each other.

View File

@ -18,7 +18,7 @@ from torch.distributed.nn.functional import (
)
def _chunk_sharding_spec_check(spec, op):
def _chunk_sharding_spec_check(spec, op) -> None:
"""
For the given op implementation check if the sharding spec is ChunkShardingSpec.
"""
@ -30,7 +30,7 @@ def _chunk_sharding_spec_check(spec, op):
def _register_sharded_op_on_local_tensor(
op, early_stop_func=None, extra_check=None, customized_func=None
):
) -> None:
"""
Handles ``__torch_function__`` dispatch for ops which are performed on
the single local tensor of the sharded tensor such as op like

View File

@ -128,7 +128,7 @@ def sharded_embedding(types, args, kwargs, pg):
)
def _validate_embedding_param(args, kwargs):
def _validate_embedding_param(args, kwargs) -> None:
"""
Validate input params of sharded embedding op.

View File

@ -150,7 +150,7 @@ def sharded_embedding_bag(types, args, kwargs, pg):
)
def _validate_embedding_bag_param(args, kwargs):
def _validate_embedding_bag_param(args, kwargs) -> None:
"""
Validate input params of sharded embeddingBag op.

View File

@ -113,7 +113,7 @@ class _ModMemStats:
values as the memory consumed in bytes.
"""
def __init__(self, mod_fqn: str):
def __init__(self, mod_fqn: str) -> None:
self.mod_fqn = mod_fqn
self.parameter_mem: int
self.buffer_mem: int

View File

@ -157,7 +157,7 @@ class MemoryTracker:
def show_traces(self, path: str = "") -> None:
import matplotlib.pyplot as plt
def _plot_figure(x, y_values, labels):
def _plot_figure(x, y_values, labels) -> None:
min_val = min(chain.from_iterable(y_values)) * 0.999
max_val = max(chain.from_iterable(y_values)) * 1.001
plt.figure()

View File

@ -55,7 +55,7 @@ class ModTracker:
A Set containing the fqn for each module currently running their forward
"""
def __init__(self):
def __init__(self) -> None:
self.parents = {"Global"}
self._active_module_cnt = {}
self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
@ -67,7 +67,7 @@ class ModTracker:
self._user_pre_bw_hook = None
self._user_post_bw_hook = None
def _maybe_set_engine_callback(self):
def _maybe_set_engine_callback(self) -> None:
# This assumes no concurrent calls to backward
if self._has_callback:
return
@ -76,7 +76,7 @@ class ModTracker:
torch.autograd.Variable._execution_engine.queue_callback(post_bw_callback)
self._post_bw_callbacks_to_enqueue.clear()
def callback():
def callback() -> None:
self.parents = {"Global"}
self._has_callback = False
@ -102,7 +102,7 @@ class ModTracker:
post_fw_hook: Optional[Callable] = None,
pre_bw_hook: Optional[Callable] = None,
post_bw_hook: Optional[Callable] = None,
):
) -> None:
"""
Registers user-specified hooks to be called before/after the forward/backward pass for each
module tracked by the ``ModTracker``. One or more can be ``None``.
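As a hedged illustration of register_user_hooks (the private module path torch.distributed._tools.mod_tracker and the context-manager usage are assumptions; hook arguments are swallowed with *args for brevity):

import torch
from torch import nn
from torch.distributed._tools.mod_tracker import ModTracker

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
tracker = ModTracker()
tracker.register_user_hooks(
    post_fw_hook=lambda *args: print("forward done, active:", sorted(tracker.parents)),
    post_bw_hook=lambda *args: print("backward done, active:", sorted(tracker.parents)),
)
with tracker:
    out = model(torch.randn(8, 4))
    out.sum().backward()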
@ -149,7 +149,7 @@ class ModTracker:
post_bw_hook, self._user_post_bw_hook, "post_bw_hook"
)
def clear_user_hooks(self):
def clear_user_hooks(self) -> None:
"""
Clears the user specified hooks registered with ``register_user_hooks``
"""
@ -170,12 +170,14 @@ class ModTracker:
return mod_name
def _get_append_fn(self, w_mod, name, is_bw):
def fn(*args):
def fn(*args) -> None:
if is_bw:
self._maybe_set_engine_callback()
if name in self.parents and not self.is_bw:
def custom_formatwarning(msg, category, filename, lineno, line=None):
def custom_formatwarning(
msg, category, filename, lineno, line=None
) -> str:
return f"{filename}:{lineno}: {category.__name__}: {msg} \n"
# pyrefly: ignore [bad-assignment]
@ -197,7 +199,7 @@ class ModTracker:
return fn
def _get_pop_fn(self, w_mod, name, is_bw):
def fn(*args):
def fn(*args) -> None:
if self._user_post_bw_hook is not None and is_bw:
self._user_post_bw_hook(w_mod(), args)
if name in self.parents:
@ -213,7 +215,7 @@ class ModTracker:
return fn
def _fw_pre_hook(self, mod, input):
def _fw_pre_hook(self, mod, input) -> None:
name = self._get_mod_name(mod)
w_mod = weakref.ref(mod)
self._get_append_fn(w_mod, name, False)()
@ -229,7 +231,7 @@ class ModTracker:
self._get_pop_fn(w_mod, name, True)
)
def _fw_post_hook(self, mod, input, output):
def _fw_post_hook(self, mod, input, output) -> None:
name = self._get_mod_name(mod)
w_mod = weakref.ref(mod)
if self._user_post_fw_hook is not None:

View File

@ -21,7 +21,7 @@ class DefaultState:
"gradient_postdivide_factor",
]
def __init__(self, process_group: dist.ProcessGroup):
def __init__(self, process_group: dist.ProcessGroup) -> None:
if process_group is None:
raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
self.process_group = process_group
@ -64,12 +64,12 @@ class LowPrecisionState(DefaultState):
self,
process_group,
parameter_type=torch.float32,
):
) -> None:
super().__init__(process_group)
self.parameter_type = parameter_type
def _decompress(state: LowPrecisionState, grad: torch.Tensor):
def _decompress(state: LowPrecisionState, grad: torch.Tensor) -> None:
"""
Casts gradients back to full parameter precision so that further computation happens in full precision.
"""
@ -92,7 +92,7 @@ def _decompress(state: LowPrecisionState, grad: torch.Tensor):
orig_grad_data.record_stream(backend.current_stream()) # type: ignore[arg-type]
def allreduce_hook(state: DefaultState, grad: torch.Tensor):
def allreduce_hook(state: DefaultState, grad: torch.Tensor) -> None:
r"""
Implement the FSDP communication hook for ``all_reduce`` algorithm and a necessary pre- and post-division of gradients.
@ -111,7 +111,9 @@ def allreduce_hook(state: DefaultState, grad: torch.Tensor):
grad.div_(state.gradient_postdivide_factor)
def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
def reduce_scatter_hook(
state: DefaultState, grad: torch.Tensor, output: torch.Tensor
) -> None:
r"""
Implement the FSDP communication hook for ``reduce_scatter`` algorithm.
@ -137,7 +139,7 @@ def _low_precision_hook(
state: LowPrecisionState,
grad: torch.Tensor,
output: Optional[torch.Tensor],
):
) -> None:
if grad.dtype != prec:
grad.data = grad.data.to(prec)
if output is not None:

View File

@ -65,7 +65,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer):
f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
self._opt_hook_state = _OptimizerHookState(f_optim, params)
def register_ddp(self, ddp_inst: DistributedDataParallel):
def register_ddp(self, ddp_inst: DistributedDataParallel) -> None:
# NOTE: using a custom communication hook and fused optimizer is not
# yet supported.
ddp_inst.register_comm_hook( # type: ignore[operator]

View File

@ -106,7 +106,7 @@ def auto_quantize(func, qtype, quant_loss=None):
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> None:
group = kwargs.get("group")
async_op = kwargs.get("async_op", False)
if async_op is True:

View File

@ -30,7 +30,7 @@ from . import (
__all__ = ["DDPCommHookType", "register_ddp_comm_hook"]
def _ddp_comm_hook_wrapper(comm_hook, model, state):
def _ddp_comm_hook_wrapper(comm_hook, model, state) -> None:
model.register_comm_hook(state, comm_hook)
@ -40,7 +40,7 @@ def _powerSGD_comm_hook_wrapper(
state,
matrix_approximation_rank,
start_powerSGD_iter=1_000,
):
) -> None:
"""
Wrap PowerSGD communication hook.
@ -123,7 +123,7 @@ class DDPCommHookType(Enum):
)
def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None):
def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None) -> None:
"""
Register ``ddp_comm_hooks`` to DDP model.

View File

@ -23,7 +23,7 @@ def _perform_local_step(
bucket: dist.GradBucket,
zero: ZeroRedundancyOptimizer,
rank: int,
):
) -> None:
r"""
Perform a local optimizer step using the gradients provided by ``bucket``.
@ -67,7 +67,7 @@ def _perform_local_step(
def _broadcast_bucket(
bucket_index: int,
zero: ZeroRedundancyOptimizer,
):
) -> None:
r"""
Broadcasts a bucket's parameters.
@ -104,7 +104,7 @@ def _broadcast_bucket(
def _save_ddp_bucket_info(
bucket: dist.GradBucket,
zero: ZeroRedundancyOptimizer,
):
) -> None:
r"""
Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``.
@ -136,7 +136,7 @@ def _hook_with_zero_step_setup(
ddp_ref: weakref.ReferenceType,
zero: ZeroRedundancyOptimizer,
bucket: dist.GradBucket,
):
) -> None:
r"""
Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.

View File

@ -60,7 +60,7 @@ def _reducer_allreduce_and_upcast_hook(
p.grad.data = p.grad.to(p.data.dtype)
# enqueue a callback to wait for this stream at end of backward
def wait_for_stream_cb():
def wait_for_stream_cb() -> None:
torch.accelerator.current_stream().wait_stream(stream)
# Remove post-backward hooks since they are re-installed in next
# iteration, similar to FSDP.

View File

@ -23,16 +23,16 @@ class _OptimizerHookState:
__slots__ = ["functional_optimizer", "params_to_optimize"]
def __init__(self, functional_optim, params=None):
def __init__(self, functional_optim, params=None) -> None:
self.functional_optimizer = functional_optim
self._check_valid_functional_optim()
self._set_params_to_optimize(params)
def _set_params_to_optimize(self, params):
def _set_params_to_optimize(self, params) -> None:
if params is not None:
self.params_to_optimize = set(params)
def _check_valid_functional_optim(self):
def _check_valid_functional_optim(self) -> None:
if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME):
raise ValueError(
f"Class {type(self.functional_optimizer)} must implement method "
@ -99,7 +99,7 @@ def _apply_optim_in_backward_hook(
# enqueue a callback to wait for this optimizer stream at the end of
# backward and set all DDP managed grads to None.
def wait_for_optim_stream_callback():
def wait_for_optim_stream_callback() -> None:
torch.accelerator.current_stream().wait_stream(
optim_stream_state.optim_stream
)

View File

@ -38,7 +38,7 @@ class PostLocalSGDState:
subgroup,
start_localSGD_iter,
post_local_gradient_allreduce=True,
):
) -> None:
"""Initialize state object with given parameters and log when localSGD start."""
logger.info(
"Local SGD will be started after %s iterations", start_localSGD_iter
@ -55,7 +55,7 @@ class PostLocalSGDState:
# Iteration/step in the training loop.
self.iter = 0
def maybe_increase_iter(self, bucket):
def maybe_increase_iter(self, bucket) -> None:
"""Track iterations and trigger log message at start of local SGD."""
# Since bucket 0 is the last bucket to allreduce in an iteration.
# Only increase `iter` when bucket 0 is processed.

View File

@ -16,7 +16,7 @@ __all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"]
logger = logging.getLogger(__name__)
def _orthogonalize(matrices, epsilon=0):
def _orthogonalize(matrices, epsilon=0) -> None:
"""
Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices.
@ -41,7 +41,7 @@ def _orthogonalize(matrices, epsilon=0):
)
def _orthogonalize_gram_schmidt(matrices, epsilon=0):
def _orthogonalize_gram_schmidt(matrices, epsilon=0) -> None:
"""
Apply Gram-Schmidt procedure to orthogonalize a batch of matrices.
@ -103,7 +103,7 @@ def _should_compress(
)
def _report_compression_stats(bucket, state):
def _report_compression_stats(bucket, state) -> None:
"""Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state."""
if bucket.is_last() and state.iter >= state.next_stats_report:
stats = state.compression_stats()
@ -188,7 +188,7 @@ class PowerSGDState:
random_seed=0,
compression_stats_logging_frequency=10_000,
batch_tensors_with_same_shape: bool = False,
):
) -> None:
logger.info(
"PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; "
"min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; "
@ -302,7 +302,7 @@ class PowerSGDState:
for slot, value in state.items():
setattr(self, slot, value)
def maybe_increase_iter(self, bucket):
def maybe_increase_iter(self, bucket) -> None:
"""Track iterations and trigger log message at start of local SGD."""
# Since bucket 0 is the last bucket to allreduce in an iteration.
# Only increase `iter` when bucket 0 is processed.

View File

@ -99,7 +99,9 @@ class HierarchicalModelAverager(averagers.ModelAverager):
`HierarchicalModelAverager` is experimental and subject to change.
"""
def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
def __init__(
self, period_group_size_dict=None, warmup_steps=0, process_group=None
) -> None:
super().__init__(process_group)
if not period_group_size_dict:
raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
@ -163,7 +165,7 @@ class HierarchicalModelAverager(averagers.ModelAverager):
params: Union[
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
],
):
) -> None:
"""
Averages parameters or parameter groups of an optimizer.

View File

@ -21,7 +21,7 @@ __all__ = [
def average_parameters(
params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
):
) -> None:
"""
Averages all the given parameters.
@ -87,6 +87,6 @@ def average_parameters_or_parameter_groups(
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
],
process_group: ProcessGroup,
):
) -> None:
"""Averages parameters of a model or parameter groups of an optimizer."""
average_parameters(iter(get_params_to_average(params)), process_group)

View File

@ -46,7 +46,7 @@ class HybridModel(torch.nn.Module):
servers.
"""
def __init__(self, emb_rref_list, device):
def __init__(self, emb_rref_list, device) -> None:
super().__init__()
self.emb_rref_list = emb_rref_list
fc1 = torch.nn.Linear(512, 256)
@ -85,7 +85,7 @@ def _retrieve_embedding_parameters(emb_rref):
return [RRef(p) for p in emb_rref.local_value().parameters()]
def _print_header():
def _print_header() -> None:
_print_cont("\n")
_print_cont(" " * 10)
for _ in [50, 75, 90, 95]:
@ -93,7 +93,7 @@ def _print_header():
_print_cont("\n")
def _print_benchmark(prefix, nelem, measurements):
def _print_benchmark(prefix, nelem, measurements) -> None:
measurements = sorted(measurements)
_print_cont(f"{prefix:8s}:")
for p in [50, 75, 90, 95]:
@ -102,7 +102,7 @@ def _print_benchmark(prefix, nelem, measurements):
_print_cont("\n")
def _print_cont(msg):
def _print_cont(msg) -> None:
print(msg, end="", flush=True)
@ -201,7 +201,7 @@ def _run_trainer(emb_rref_list, rank):
return rank, measurements, batch_size # type: ignore[possibly-undefined]
def run_worker(rank, world_size):
def run_worker(rank, world_size) -> None:
r"""
Initialize RPC, calls the function, and shuts down RPC.
"""

View File

@ -34,7 +34,7 @@ class _CheckpointRequestIdentifier:
checkpoint_id: Union[str, os.PathLike, None]
uuid: str
def __init__(self, checkpoint_id: Union[str, os.PathLike, None]):
def __init__(self, checkpoint_id: Union[str, os.PathLike, None]) -> None:
self.checkpoint_id = checkpoint_id
self.uuid = str(uuid4())
@ -58,7 +58,7 @@ class _ProcessGroupInitInfo:
tcp_store_master_port: int
use_prefix_store: bool
def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
self.local_rank = dist.get_node_local_rank(fallback_rank=0)
self.global_rank = dist.get_rank(process_group)
self.world_size = dist.get_world_size(process_group)
@ -103,7 +103,7 @@ class _AsyncCheckpointProcess:
def __init__(
self,
pg_init_info: _ProcessGroupInitInfo,
):
) -> None:
self.ctx = mp.get_context("spawn")
self._process_pipe, child_end = self.ctx.Pipe()

View File

@ -39,7 +39,7 @@ class _Checkpointer:
no_dist: bool = False,
load_planner: Optional[LoadPlanner] = None,
save_planner: Optional[SavePlanner] = None,
):
) -> None:
"""Initializes the Checkpointer instance.
Args:

View File

@ -105,7 +105,7 @@ class Barrier(abc.ABC):
"""
@abc.abstractmethod
def __init__(self, **kwargs: dict[str, Any]):
def __init__(self, **kwargs: dict[str, Any]) -> None:
"""
Initialize a barrier.
@ -185,7 +185,7 @@ class TCPStoreBarrier(Barrier):
tcpstore_port: int,
master_address: str,
timeout_secs: int,
):
) -> None:
"""
Initialize a TCPStoreBarrier.

View File

@ -74,7 +74,7 @@ class CheckpointProcess:
subprocess_init_args: tuple[Any, ...],
checkpoint_writer_init_fn: Callable[..., CheckpointWriter],
checkpoint_writer_init_args: dict[str, Any],
):
) -> None:
self._executor = ThreadPoolExecutor(max_workers=1)
self._rank_info = rank_info
self._config = config

View File

@ -32,7 +32,7 @@ class CheckpointReader:
def __init__(
self,
rank_info: RankInfo,
):
) -> None:
"""
Initialize a CheckpointReader.

View File

@ -74,7 +74,7 @@ class CheckpointWriter:
rank_info: RankInfo,
barrier: Optional[Barrier] = None,
commit_hook: Optional[WriterHook] = None,
):
) -> None:
"""
Initialize a CheckpointWriter.

View File

@ -110,7 +110,7 @@ class SyncCheckpointer(Checkpointer):
self,
writer: CheckpointWriter,
reader: CheckpointReader,
):
) -> None:
"""
Initialize a synchronous checkpointer.
@ -225,7 +225,7 @@ class AsyncCheckpointer(Checkpointer):
checkpoint_stager: CheckpointStager,
checkpoint_process: CheckpointProcess,
reader: CheckpointReader,
):
) -> None:
"""
Initialize an asynchronous checkpointer.

View File

@ -138,7 +138,7 @@ class DefaultStager(CheckpointStager):
def __init__(
self,
config: CheckpointStagerConfig = CheckpointStagerConfig(),
):
) -> None:
self._config = config
self._state_dict_stager = StateDictStager(
pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory

View File

@ -32,7 +32,7 @@ class StateDictStager:
pin_memory: bool = False,
share_memory: bool = False,
pin_memory_min_bytes: int = 5,
):
) -> None:
if pin_memory and not torch.cuda.is_available():
warnings.warn(
"Ignoring pin_memory flag for checkpoint staging as pinning memory"
@ -258,7 +258,7 @@ class StateDictStager:
return y
def close(self):
def close(self) -> None:
"""
Clean up all cached storages and release associated resources.
@ -386,7 +386,7 @@ class StateDictStager:
self._keep_alive(x, memo) # Make sure x lives at least as long as d
return y
def _keep_alive(self, x, memo):
def _keep_alive(self, x, memo) -> None:
"""Keeps a reference to the object x in the memo.
Because we remember objects by their id, we have

View File

@ -22,7 +22,7 @@ def _is_wrapped_exception(obj: Any) -> bool:
class CheckpointException(BaseException):
"""Exception raised if failure was detected as part of a checkpoint load or save."""
def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]):
def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]) -> None:
super().__init__(msg, failures)
self._failures = failures

View File

@ -401,7 +401,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
"""
def __init__(self, keys=None, *args, **kwargs):
def __init__(self, keys=None, *args, **kwargs) -> None:
self.keys = keys
super().__init__(*args, **kwargs)

View File

@ -66,7 +66,7 @@ def _init_model(rank, world_size):
return model, optim
def _print(msg):
def _print(msg) -> None:
if dist.get_rank() == 0:
print(msg)
@ -80,7 +80,7 @@ def _input():
return x, y
def run(rank, world_size):
def run(rank, world_size) -> None:
# Set up world pg
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

View File

@ -55,7 +55,7 @@ def print_params(stage, model_1, model_2, optim_1, optim_2):
)
def run_fsdp_checkpoint_example(rank, world_size):
def run_fsdp_checkpoint_example(rank, world_size) -> None:
# Set up world pg
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

View File

@ -39,7 +39,7 @@ class Model(torch.nn.Module):
return torch.rand(8, 8, device="cuda")
def _make_stateful(model, optim):
def _make_stateful(model, optim) -> None:
_patch_model_state_dict(model)
_patch_optimizer_state_dict(model, optimizers=optim)
@ -70,7 +70,7 @@ def _init_model(device, world_size):
return model, optim
def run(rank, world_size, device="cuda"):
def run(rank, world_size, device="cuda") -> None:
# Set up world pg
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

View File

@ -258,7 +258,7 @@ class _StorageWriterTransforms:
# are appended.
class NoCloseWriter(io.IOBase):
def __init__(self, raw: io.IOBase):
def __init__(self, raw: io.IOBase) -> None:
self.raw = raw
def writeable(self) -> bool:
@ -267,7 +267,7 @@ class _StorageWriterTransforms:
def write(self, b: Buffer) -> int:
return self.raw.write(b)
def close(self):
def close(self) -> None:
self.flush()
self.raw.flush()
# but not close.

View File

@ -208,7 +208,7 @@ class DynamicMetaLoadPlanner(DefaultLoadPlanner):
def dcp_to_torch_save(
dcp_checkpoint_dir: Union[str, os.PathLike],
torch_save_path: Union[str, os.PathLike],
):
) -> None:
"""
Given a directory containing a DCP checkpoint, this function will convert it into a
Torch save file.
@ -233,7 +233,7 @@ def dcp_to_torch_save(
def torch_save_to_dcp(
torch_save_path: Union[str, os.PathLike],
dcp_checkpoint_dir: Union[str, os.PathLike],
):
) -> None:
"""
Given the location of a torch save file, converts it into a DCP checkpoint.
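Both converters are single-call utilities; a brief sketch with hypothetical checkpoint paths:

from torch.distributed.checkpoint.format_utils import (
    dcp_to_torch_save,
    torch_save_to_dcp,
)

# DCP checkpoint directory -> regular torch.save file, and back again.
dcp_to_torch_save("checkpoints/step_1000", "checkpoints/step_1000.pt")
torch_save_to_dcp("checkpoints/step_1000.pt", "checkpoints/step_1000_dcp")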

View File

@ -110,7 +110,7 @@ def _dcp_method_logger(
return decorator
def _init_logger(rank: int):
def _init_logger(rank: int) -> None:
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)

View File

@ -177,7 +177,7 @@ class MetadataIndex:
fqn: str,
offset: Optional[Sequence[int]] = None,
index: Optional[int] = None,
):
) -> None:
# We must use object.__setattr__ due to frozen=True
object.__setattr__(self, "fqn", fqn)
object.__setattr__(self, "index", index)

View File

@ -34,7 +34,7 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
thread_count: int = 1,
target_dtype: torch.dtype = torch.float32,
block_size: int = 128,
):
) -> None:
"""
Initialize the HuggingFace storage reader to load quantized checkpoints
@ -66,7 +66,7 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
return metadata
def _load_quantization_metadata(self):
def _load_quantization_metadata(self) -> None:
"""Load quantization metadata from the checkpoint."""
checkpoint_path = Path(self.path)
# Load weight mapping from index file
@ -77,7 +77,7 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
weight_map = index_data.get("weight_map", {})
self._build_weight_scale_mapping(weight_map)
def _build_weight_scale_mapping(self, weight_map: dict[str, str]):
def _build_weight_scale_mapping(self, weight_map: dict[str, str]) -> None:
"""Analyze and build weight-scale tensor pairs from weight mapping."""
# Store the complete weight map for file location lookups
self._weight_map = weight_map

View File

@ -174,7 +174,7 @@ class DefaultStager(AsyncStager):
def __init__(
self,
config: StagingOptions = StagingOptions(),
):
) -> None:
self._config = config
self._state_dict_stager = StateDictStager(
pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory
@ -292,7 +292,7 @@ class BlockingAsyncStager(AsyncStager):
self,
cache_staged_state_dict: bool = False,
type_check: bool = False,
):
) -> None:
"""
Initializes the BlockingAsyncStager.
@ -352,7 +352,7 @@ class _ReplicationStager(AsyncStager):
timeout: timedelta = timedelta(minutes=30),
device: torch.device = torch.device("cpu"),
storage_dir: Optional[str] = None,
):
) -> None:
self._pg = pg
self._timeout = timeout
# pyrefly: ignore [read-only]

View File

@ -1560,7 +1560,7 @@ def _patch_model_state_dict(
options=options,
)
def load_state_dict_call(state_dict: dict[str, Any]):
def load_state_dict_call(state_dict: dict[str, Any]) -> None:
_load_state_dict_call(model_state_dict=state_dict)
model.load_state_dict = load_state_dict_call
@ -1619,7 +1619,7 @@ def _patch_optimizer_state_dict(
options=options,
)
def load_state_dict_call(state_dict: dict[str, Any]):
def load_state_dict_call(state_dict: dict[str, Any]) -> None:
_load_state_dict_call(optim_state_dict=state_dict)
_patched_state_dict.add(state_dict_call)

View File

@ -287,7 +287,7 @@ def _load_state_dict(
central_plan = global_plan[0]
@_dcp_method_logger(**ckpt_kwargs)
def read_data():
def read_data() -> None:
if planner is None:
raise AssertionError("planner is None")
if central_plan is None:

View File

@ -344,7 +344,7 @@ def async_save(
def callback(
original_staging_future: Future[STATE_DICT_TYPE],
return_staging_future: Future[None] = return_staging_future,
):
) -> None:
try:
original_staging_future.result()
return_staging_future.set_result(None)
@ -363,7 +363,7 @@ def async_save(
else:
@_dcp_method_logger(log_exceptions=True)
def maybe_synchronize_staging():
def maybe_synchronize_staging() -> None:
if async_stager.should_synchronize_after_execute:
async_stager.synchronize_staging()

View File

@ -87,7 +87,7 @@ class _DistWrapper:
group: Optional[dist.ProcessGroup],
use_dist: bool,
coordinator_rank: int,
):
) -> None:
self.group = group
self.use_dist = use_dist
self.coordinator_rank = coordinator_rank
@ -383,7 +383,7 @@ def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]:
class _ReaderView(io.IOBase):
def __init__(self, base_stream: io.IOBase, offset: int, len: int):
def __init__(self, base_stream: io.IOBase, offset: int, len: int) -> None:
super().__init__()
self.offset = offset
self.len = len

View File

@ -29,7 +29,7 @@ if not is_available():
class _DeviceMeshStub:
pass
def _init_device_mesh_stub():
def _init_device_mesh_stub() -> None:
pass
sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined]

View File

@ -372,7 +372,7 @@ class Backend(str): # noqa: SLOT000
class BackendConfig:
"""Backend configuration class."""
def __init__(self, backend: Backend):
def __init__(self, backend: Backend) -> None:
"""Init."""
self.device_backend_map: dict[str, Backend] = {}
# pyrefly: ignore [bad-assignment]
@ -442,7 +442,7 @@ class BackendConfig:
logger.info("Using backend config: %s", self.device_backend_map)
def __repr__(self):
def __repr__(self) -> str:
"""Return all the device:backend pairs separated by commas."""
return ",".join(
f"{device}:{backend}" for device, backend in self.device_backend_map.items()
@ -508,7 +508,7 @@ class P2POp:
group: Optional[ProcessGroup] = None,
tag: int = 0,
group_peer: Optional[int] = None,
):
) -> None:
"""Init."""
self.op = op
self.tensor = tensor
@ -534,7 +534,7 @@ class P2POp:
return object.__new__(cls)
def __repr__(self):
def __repr__(self) -> str:
my_group_rank = get_rank(self.group)
op_name = self.op.__name__
group_name = self.group.group_name if self.group else "default_pg"
@ -569,7 +569,7 @@ class _CollOp:
dst_tensor: Optional[torch.Tensor] = None,
redop: Optional[ReduceOp] = None,
root: Optional[int] = None,
):
) -> None:
self.op = op
self.tensor = tensor
self.dst_tensor = dst_tensor
@ -734,7 +734,7 @@ class _WorldMeta(type):
return _world.default_pg
@WORLD.setter
def WORLD(cls, pg: Optional[ProcessGroup]):
def WORLD(cls, pg: Optional[ProcessGroup]) -> None:
_world.default_pg = pg
@ -1180,7 +1180,7 @@ def _canonicalize_group_rank(
return group_rank
def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str):
def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str) -> None:
if group.rank() == rank:
raise ValueError(
f"Invalid {rank_type} rank: {rank_type} rank should not be the same as "
@ -1832,7 +1832,7 @@ def init_process_group(
old_hook = sys.excepthook
excepthook_prefix = f"[rank{get_rank()}]"
def _distributed_excepthook(*args):
def _distributed_excepthook(*args) -> None:
old_stderr = sys.stderr
sys.stderr = buf = io.StringIO()
try:
@ -2216,7 +2216,7 @@ def _new_process_group_helper(
return pg, prefix_store
def destroy_process_group(group: Optional[ProcessGroup] = None):
def destroy_process_group(group: Optional[ProcessGroup] = None) -> None:
"""
Destroy a given process group, and deinitialize the distributed package.
@ -2305,7 +2305,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
_unregister_process_group(pg.group_name)
def _abort_process_group(group: Optional[ProcessGroup] = None):
def _abort_process_group(group: Optional[ProcessGroup] = None) -> None:
"""
Abort a given process group. If group.WORLD (i.e. `None`) is given, all
process groups including the default one will be aborted.
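For reference, the usual lifecycle around destroy_process_group, shown as a minimal single-process gloo sketch (the address and port are arbitrary):

import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
try:
    dist.all_reduce(torch.ones(1))
finally:
    # Tears down the default group and deinitializes the distributed package.
    dist.destroy_process_group()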
@ -2623,11 +2623,11 @@ class _CoalescingManager:
def __init__(self) -> None:
self.works: list[Work] = []
def append(self, work: Optional[Work] = None):
def append(self, work: Optional[Work] = None) -> None:
if work:
self.works.append(work)
def wait(self):
def wait(self) -> None:
for work in self.works:
work.wait()
@ -3171,7 +3171,7 @@ def _tensor_to_object(tensor, tensor_size, group):
@_exception_logger
def all_gather_object(object_list, obj, group=None):
def all_gather_object(object_list, obj, group=None) -> None:
"""
Gathers picklable objects from the whole group into a list.
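A short usage sketch for all_gather_object; it assumes the process group is already initialized and the gathered objects are picklable:

import torch.distributed as dist

gathered: list[object] = [None] * dist.get_world_size()
dist.all_gather_object(gathered, {"rank": dist.get_rank(), "status": "ok"})
# gathered[i] now holds the object contributed by rank i, on every rank.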
@ -3272,7 +3272,7 @@ def gather_object(
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
group_dst: Optional[int] = None,
):
) -> None:
"""
Gathers picklable objects from the whole group in a single process.
@ -3404,7 +3404,7 @@ def send_object_list(
device: Optional[torch.device] = None,
group_dst: Optional[int] = None,
use_batch: bool = False,
):
) -> None:
"""
Sends picklable objects in ``object_list`` synchronously.
@ -3663,7 +3663,7 @@ def broadcast_object_list(
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group_src: Optional[int] = None,
):
) -> None:
"""
Broadcasts picklable objects in ``object_list`` to the whole group.
@ -3795,7 +3795,7 @@ def scatter_object_list(
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
group_src: Optional[int] = None,
):
) -> None:
"""
Scatters picklable objects in ``scatter_object_input_list`` to the whole group.
@ -4250,7 +4250,7 @@ def all_gather_coalesced(
# Otherwise, the backend has sync'ed at CPP level
def _validate_output_list_for_rank(my_rank, dst, gather_list):
def _validate_output_list_for_rank(my_rank, dst, gather_list) -> None:
if dst == my_rank:
if not gather_list:
raise ValueError(

View File

@ -162,7 +162,7 @@ class Worker:
role_rank: int = -1,
world_size: int = -1,
role_world_size: int = -1,
):
) -> None:
# unique identifier for this worker
self.id: Any = None
@ -188,14 +188,14 @@ class Worker:
# the role world size may change between re-rendezvous.
self.role_world_size: int = role_world_size
def __str__(self):
def __str__(self) -> str:
return (
f"local_rank={self.local_rank},global_rank={self.global_rank}"
f",role_rank={self.role_rank},world_size={self.world_size}"
f",role_world_size={self.role_world_size}"
)
def __repr__(self):
def __repr__(self) -> str:
return str(self)
@ -271,7 +271,7 @@ class WorkerGroup:
"master_port",
]
def __init__(self, spec: WorkerSpec):
def __init__(self, spec: WorkerSpec) -> None:
self.spec = spec
self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]
@ -295,7 +295,7 @@ class _RoleInstanceInfo:
__slots__ = ["role", "rank", "local_world_size"]
def __init__(self, role: str, rank: int, local_world_size: int):
def __init__(self, role: str, rank: int, local_world_size: int) -> None:
r"""Initialize the agent class instance.
Args:
@ -449,7 +449,7 @@ class SimpleElasticAgent(ElasticAgent):
such as one particular type of worker role.
"""
def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300) -> None:
self._worker_group = WorkerGroup(spec)
self._remaining_restarts = self._worker_group.spec.max_restarts
self._store = None
@ -845,7 +845,7 @@ class SimpleElasticAgent(ElasticAgent):
f"torchelastic.worker.status.{state}", source=source, metadata=metadata
)
def _record_metrics(self, group_results: RunResult):
def _record_metrics(self, group_results: RunResult) -> None:
is_failed = group_results.is_failed()
self._record_flakiness_metric(is_failed)
spec = self._worker_group.spec
@ -864,14 +864,14 @@ class SimpleElasticAgent(ElasticAgent):
"run_failed_no_retries", is_failed and not restarts_happened
)
def _record_metric_with_condition(self, metric_name, condition):
def _record_metric_with_condition(self, metric_name, condition) -> None:
spec = self._worker_group.spec
if condition:
put_metric(f"workers.{spec.role}.{metric_name}", 1)
else:
put_metric(f"workers.{spec.role}.{metric_name}", 0)
def _record_flakiness_metric(self, is_failed: bool = False):
def _record_flakiness_metric(self, is_failed: bool = False) -> None:
if is_failed:
flakiness = 100.0
else:
@ -952,7 +952,7 @@ class SimpleElasticAgent(ElasticAgent):
f"[{role}] Worker group in {state.name} state"
)
def _exit_barrier(self):
def _exit_barrier(self) -> None:
"""
Define a barrier that keeps the agent process alive until all workers finish.

View File

@ -153,7 +153,7 @@ class LocalElasticAgent(SimpleElasticAgent):
start_method="spawn",
exit_barrier_timeout: float = 300,
log_line_prefix_template: Optional[str] = None,
):
) -> None:
super().__init__(spec, exit_barrier_timeout)
self._start_method = start_method
self._pcontext: Optional[PContext] = None

View File

@ -44,7 +44,7 @@ class Event:
timestamp: int = 0
metadata: dict[str, EventMetadataValue] = field(default_factory=dict)
def __str__(self):
def __str__(self) -> str:
return self.serialize()
@staticmethod
@ -99,7 +99,7 @@ class RdzvEvent:
local_id: Optional[int] = None
error_trace: str = ""
def __str__(self):
def __str__(self) -> str:
return self.serialize()
@staticmethod

View File

@ -158,7 +158,7 @@ from .api import ( # noqa: F401
)
def initialize_metrics(cfg: Optional[MetricsConfig] = None):
def initialize_metrics(cfg: Optional[MetricsConfig] = None) -> None:
pass

View File

@ -37,7 +37,7 @@ MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value
class MetricsConfig:
__slots__ = ["params"]
def __init__(self, params: Optional[dict[str, str]] = None):
def __init__(self, params: Optional[dict[str, str]] = None) -> None:
self.params = params
if self.params is None:
self.params = {}
@ -50,23 +50,23 @@ class MetricHandler(abc.ABC):
class ConsoleMetricHandler(MetricHandler):
def emit(self, metric_data: MetricData):
def emit(self, metric_data: MetricData) -> None:
print(
f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
)
class NullMetricHandler(MetricHandler):
def emit(self, metric_data: MetricData):
def emit(self, metric_data: MetricData) -> None:
pass
class MetricStream:
def __init__(self, group_name: str, handler: MetricHandler):
def __init__(self, group_name: str, handler: MetricHandler) -> None:
self.group_name = group_name
self.handler = handler
def add_value(self, metric_name: str, metric_value: int):
def add_value(self, metric_name: str, metric_value: int) -> None:
self.handler.emit(
MetricData(time.time(), self.group_name, metric_name, metric_value)
)
@ -77,7 +77,7 @@ _default_metrics_handler: MetricHandler = NullMetricHandler()
# pyre-fixme[9]: group has type `str`; used as `None`.
def configure(handler: MetricHandler, group: Optional[str] = None):
def configure(handler: MetricHandler, group: Optional[str] = None) -> None:
if group is None:
global _default_metrics_handler
# pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
@ -188,7 +188,9 @@ def profile(group=None):
return wrap
def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
def put_metric(
metric_name: str, metric_value: int, metric_group: str = "torchelastic"
) -> None:
"""
Publish a metric data point.
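A usage sketch matching the signature above (the metric name is hypothetical):

from torch.distributed.elastic.metrics import put_metric

put_metric("checkpoint_writes", 1)                 # default "torchelastic" group
put_metric("checkpoint_writes", 1, "my_trainer")   # explicit metric group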
@ -206,7 +208,7 @@ def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchel
"Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead",
category=FutureWarning,
)
def publish_metric(metric_group: str, metric_name: str, metric_value: int):
def publish_metric(metric_group: str, metric_name: str, metric_value: int) -> None:
metric_stream = getStream(metric_group)
metric_stream.add_value(metric_name, metric_value)

View File

@ -102,7 +102,7 @@ def _get_default_signal() -> signal.Signals:
return signal.SIGTERM
def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str):
def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str) -> None:
actual_keys = set(d.keys())
expected_keys = set(range(nprocs))
@ -472,7 +472,7 @@ class PContext(abc.ABC):
log_line_prefixes: Optional[dict[int, str]] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
):
) -> None:
self.name = name
# validate that all mappings have the same number of keys and
# all local ranks are accounted for
@ -715,7 +715,7 @@ class MultiprocessContext(PContext):
numa_options: Optional[NumaOptions] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
):
) -> None:
super().__init__(
name,
entrypoint,
@ -743,7 +743,7 @@ class MultiprocessContext(PContext):
self._numa_options: Optional[NumaOptions] = numa_options
def _start(self):
def _start(self) -> None:
if self._pc:
raise ValueError(
"The process context already initialized."
@ -904,7 +904,7 @@ class SubprocessContext(PContext):
numa_options: Optional[NumaOptions] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
):
) -> None:
super().__init__(
name,
entrypoint,
@ -922,7 +922,7 @@ class SubprocessContext(PContext):
self.subprocess_handlers: dict[int, SubprocessHandler] = {}
self._numa_options: Optional[NumaOptions] = numa_options
def _start(self):
def _start(self) -> None:
if self.subprocess_handlers:
raise ValueError(
"The subprocess handlers already initialized. Most likely the start method got called twice."
@ -940,7 +940,7 @@ class SubprocessContext(PContext):
for local_rank in range(self.nprocs)
}
def _capture_process_failures(self, done_local_ranks: set[int]):
def _capture_process_failures(self, done_local_ranks: set[int]) -> None:
for local_rank in self._running_local_ranks:
handler = self.subprocess_handlers[local_rank]
exitcode = handler.proc.poll()

View File

@ -157,7 +157,7 @@ class ProcessFailure:
timestamp = int(message["extraInfo"]["timestamp"])
return (message, timestamp)
def _set_no_reply_file(self):
def _set_no_reply_file(self) -> None:
self.error_file = _NOT_AVAILABLE
self.error_file_data = _EMPTY_ERROR_DATA
self.message = ""
@ -237,7 +237,7 @@ class ChildFailedError(Exception):
of trainer 1's error file to the scheduler's init process.
"""
def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]):
def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]) -> None:
self.name = name
self.failures = failures
assert (

View File

@ -92,7 +92,7 @@ class ErrorHandler:
rootcause_error_file: str,
rootcause_error: dict[str, Any],
error_code: int = 0,
):
) -> None:
"""Modify the rootcause_error read from the file, to correctly set the exit code."""
if "message" not in rootcause_error:
logger.warning(
@ -110,7 +110,7 @@ class ErrorHandler:
else:
rootcause_error["message"]["errorCode"] = error_code
def dump_error_file(self, rootcause_error_file: str, error_code: int = 0):
def dump_error_file(self, rootcause_error_file: str, error_code: int = 0) -> None:
"""Dump parent error file from child process's root cause error and error code."""
with open(rootcause_error_file) as fp:
rootcause_error = json.load(fp)
@ -147,7 +147,7 @@ class ErrorHandler:
rootcause_error_file,
)
def _rm(self, my_error_file):
def _rm(self, my_error_file) -> None:
if os.path.isfile(my_error_file):
# Log the contents of the original file.
with open(my_error_file) as fp:

View File

@ -87,7 +87,7 @@ def redirect(std: str, to_file: str):
python_std = _python_std(std)
std_fd = python_std.fileno()
def _redirect(dst):
def _redirect(dst) -> None:
libc.fflush(c_std)
python_std.flush()
os.dup2(dst.fileno(), std_fd)

View File

@ -42,7 +42,7 @@ class SubprocessHandler:
stderr: Optional[str],
local_rank_id: int,
numa_options: Optional[NumaOptions],
):
) -> None:
self._stdout = open(stdout, "w") if stdout else None
self._stderr = open(stderr, "w") if stderr else None
# inherit parent environment vars

View File

@ -32,7 +32,7 @@ def tail_logfile(
finished: Event,
interval_sec: float,
log_line_filter: Optional[Callable[[str], bool]] = None,
):
) -> None:
while not os.path.exists(file):
if finished.is_set():
return
@ -102,7 +102,7 @@ class TailLog:
log_line_prefixes: Optional[dict[int, str]] = None,
interval_sec: float = 0.1,
log_line_filter: Callable[[str], bool] = (lambda _: True),
):
) -> None:
n = len(log_files)
self._threadpool = None
if n > 0:

View File

@ -115,7 +115,7 @@ class RendezvousInfo:
rank: int,
world_size: int,
bootstrap_store_info: RendezvousStoreInfo,
):
) -> None:
self._store = store
self._rank = rank
self._world_size = world_size
@ -267,7 +267,7 @@ class RendezvousParameters:
max_nodes: int,
local_addr: Optional[str] = None,
**kwargs,
):
) -> None:
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")

View File

@ -183,7 +183,7 @@ class RendezvousTimeout:
"""Get the keep-alive heartbeat timeout."""
return self._heartbeat
def _set_timeouts(self, **timeouts: Optional[timedelta]):
def _set_timeouts(self, **timeouts: Optional[timedelta]) -> None:
for name, timeout in timeouts.items():
if timeout is None:
timeout = self._DEFAULT_TIMEOUTS[name]
@ -395,7 +395,7 @@ class _BackendRendezvousStateHolder(_RendezvousStateHolder):
self._last_sync_time = -1
self._dead_nodes = []
def _record(self, message: str, node_state: NodeState = NodeState.RUNNING):
def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
construct_and_record_rdzv_event(
name=f"{self.__class__.__name__}.{get_method_name()}",
run_id=self._settings.run_id,

View File

@ -153,7 +153,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
+--------------------------------------------+--------------------------+
"""
def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]):
def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]) -> None:
"""
Args:
rdzv_impl: the implementation of the rendezvous
@ -163,7 +163,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
self._rdzv_impl = rdzv_impl
self._local_addr = local_addr
def __del__(self):
def __del__(self) -> None:
# TODO: look into using weakref here instead.
del self._rdzv_impl
@ -189,7 +189,7 @@ class EtcdRendezvousHandler(RendezvousHandler):
# No rendezvous state, so it cannot be closed.
return False
def set_closed(self):
def set_closed(self) -> None:
self._rdzv_impl.set_closed()
def num_nodes_waiting(self):
@ -231,7 +231,7 @@ class EtcdRendezvous:
num_max_workers,
timeout,
last_call_timeout,
):
) -> None:
self.client = client
logger.info("Etcd machines: %s", self.client.machines)
@ -270,7 +270,7 @@ class EtcdRendezvous:
except etcd.EtcdAlreadyExist:
pass
def __del__(self):
def __del__(self) -> None:
# TODO: look into using weakref here instead.
if self._lease_run_id_stop is not None:
self._lease_run_id_stop.set()
@ -442,7 +442,7 @@ class EtcdRendezvous:
# Rendezvous version number; our rank in it; world size
return state["version"], this_rank, len(state["participants"])
def handle_existing_rendezvous(self, expected_version):
def handle_existing_rendezvous(self, expected_version) -> None:
"""
Handle the case when there's an existing (state 'final) rendezvous already
in place, and we have to announce ourselves waiting, and wait until
@ -678,7 +678,7 @@ class EtcdRendezvous:
except etcd.EtcdCompareFailed:
logger.info("Announce self as waiting CAS unsuccessful, retrying")
def wait_for_rendezvous_to_free(self, expected_version):
def wait_for_rendezvous_to_free(self, expected_version) -> None:
"""
When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join.
@ -746,7 +746,7 @@ class EtcdRendezvous:
raise RendezvousTimeoutError
active_version, state = self.get_rdzv_state()
def handle_join_last_call(self, expected_version, deadline):
def handle_join_last_call(self, expected_version, deadline) -> None:
"""
After we reach min number of workers, one particular worker takes on the
responsibility of waiting an additional timeout before closing the join window.
@ -820,7 +820,7 @@ class EtcdRendezvous:
cas_delay()
active_version, state = self.get_rdzv_state()
def set_closed(self):
def set_closed(self) -> None:
"""
Mark rendezvous 'closed' for current run_id, which is used to signal other
participants to not attempt to perform (re-)rendezvous. This is useful
@ -868,13 +868,13 @@ class EtcdRendezvous:
# Unfortunately, we have to do another fetch in order to get last etcd_index.
return self.get_rdzv_state()
def get_path(self, path):
def get_path(self, path) -> str:
if not path.startswith("/"):
path = "/" + path
return f"{self._prefix}run_{self._run_id}{path}"
def create_path_if_not_exists(self, full_path, ttl=None):
def create_path_if_not_exists(self, full_path, ttl=None) -> None:
try:
self.client.write(
key=full_path, value=None, dir=True, prevExist=False, ttl=ttl
@ -888,7 +888,7 @@ class EtcdRendezvous:
# release the Python's GIL! An example of this is calling a pybind11
# extension function that is blocking / long-running, but is not
# doing a scoped release of the GIL.
def lease_worker(client, path, ttl, stop_event):
def lease_worker(client, path, ttl, stop_event) -> None:
while True:
try:
client.refresh(path, ttl=ttl)
@ -912,7 +912,7 @@ class EtcdRendezvous:
return lease_stop_event
def store_extra_data(self, rdzv_version, key, value):
def store_extra_data(self, rdzv_version, key, value) -> None:
node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data")
try:
# If first time we are storing anything:

View File

@ -64,7 +64,7 @@ def find_free_port():
raise RuntimeError("Failed to create a socket")
def stop_etcd(subprocess, data_dir: Optional[str] = None):
def stop_etcd(subprocess, data_dir: Optional[str] = None) -> None:
if subprocess and subprocess.poll() is None:
logger.info("stopping etcd server")
subprocess.terminate()
@ -107,7 +107,7 @@ class EtcdServer:
etcd_binary_path: path of etcd server binary (see above for fallback path)
"""
def __init__(self, data_dir: Optional[str] = None):
def __init__(self, data_dir: Optional[str] = None) -> None:
self._port = -1
self._host = "localhost"

View File

@ -23,7 +23,7 @@ except ModuleNotFoundError:
# Delay (sleep) for a small random amount to reduce CAS failures.
# This does not affect correctness, but will reduce requests to etcd server.
def cas_delay():
def cas_delay() -> None:
time.sleep(random.uniform(0, 0.1))
@ -41,7 +41,7 @@ class EtcdStore(Store):
etcd_store_prefix,
# Default timeout same as in c10d/Store.hpp
timeout: Optional[datetime.timedelta] = None,
):
) -> None:
super().__init__() # required for pybind trampoline.
self.client = etcd_client
@ -53,7 +53,7 @@ class EtcdStore(Store):
if not self.prefix.endswith("/"):
self.prefix += "/"
def set(self, key, value):
def set(self, key, value) -> None:
"""
Write a key/value pair into ``EtcdStore``.
@ -121,7 +121,7 @@ class EtcdStore(Store):
except etcd.EtcdCompareFailed:
cas_delay()
def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None) -> None:
"""
Wait until all of the keys are published, or until timeout.

View File

@ -44,7 +44,7 @@ class StaticTCPRendezvous(RendezvousHandler):
world_size: int,
run_id: str,
timeout: int,
):
) -> None:
self.master_addr = master_addr
self.master_port = master_port
self.rank = rank
@ -82,13 +82,13 @@ class StaticTCPRendezvous(RendezvousHandler):
bootstrap_store_info,
)
def is_closed(self):
def is_closed(self) -> bool:
return False
def set_closed(self):
def set_closed(self) -> None:
pass
def num_nodes_waiting(self):
def num_nodes_waiting(self) -> int:
return 0
def get_run_id(self) -> str:

View File

@ -279,7 +279,7 @@ class _PeriodicTimer:
ctx.function(*ctx.args, **ctx.kwargs)
@staticmethod
def _stop_thread(thread, stop_event):
def _stop_thread(thread, stop_event) -> None:
stop_event.set()
thread.join()

View File

@ -39,7 +39,7 @@ class TimerRequest:
__slots__ = ["worker_id", "scope_id", "expiration_time"]
def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
def __init__(self, worker_id: Any, scope_id: str, expiration_time: float) -> None:
self.worker_id = worker_id
self.scope_id = scope_id
self.expiration_time = expiration_time
@ -119,7 +119,7 @@ class TimerServer(abc.ABC):
def __init__(
self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
):
) -> None:
"""
:param request_queue: Consumer ``RequestQueue``
:param max_interval: max time (in seconds) to wait
@ -179,14 +179,14 @@ class TimerServer(abc.ABC):
)
return True
def _watchdog_loop(self):
def _watchdog_loop(self) -> None:
while not self._stop_signaled:
try:
self._run_watchdog()
except Exception:
logger.exception("Error running watchdog")
def _run_watchdog(self):
def _run_watchdog(self) -> None:
batch_size = max(1, self._request_queue.size())
timer_requests = self._request_queue.get(batch_size, self._max_interval)
self.register_timers(timer_requests)
@ -237,7 +237,7 @@ class TimerServer(abc.ABC):
_timer_client: Optional[TimerClient] = None
def configure(timer_client: TimerClient):
def configure(timer_client: TimerClient) -> None:
"""
Configures a timer client. Must be called before using ``expires``.
"""

View File

@ -19,6 +19,6 @@ __all__ = ["log_debug_info_for_expired_timers"]
def log_debug_info_for_expired_timers(
run_id: str,
expired_timers: dict[int, list[str]],
):
) -> None:
if expired_timers:
logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)

View File

@ -263,7 +263,7 @@ class FileTimerServer:
os.remove(self._file_path)
@staticmethod
def is_process_running(pid: int):
def is_process_running(pid: int) -> bool | None:
"""
Check whether the process with the given PID is running.
"""

View File

@ -29,16 +29,16 @@ class LocalTimerClient(TimerClient):
GPU devices.
"""
def __init__(self, mp_queue):
def __init__(self, mp_queue) -> None:
super().__init__()
self._mp_queue = mp_queue
def acquire(self, scope_id, expiration_time):
def acquire(self, scope_id, expiration_time) -> None:
pid = os.getpid()
acquire_request = TimerRequest(pid, scope_id, expiration_time)
self._mp_queue.put(acquire_request)
def release(self, scope_id):
def release(self, scope_id) -> None:
pid = os.getpid()
release_request = TimerRequest(pid, scope_id, -1)
self._mp_queue.put(release_request)
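
An illustrative producer/consumer sketch of the acquire/release protocol above: plain tuples mirror `TimerRequest`'s `(worker_id, scope_id, expiration_time)` layout, with a negative expiration marking a release.

```python
import multiprocessing as mp
import os
import time

# Illustrative wire format only; the real client enqueues TimerRequest objects.
def client(q) -> None:
    pid = os.getpid()
    q.put((pid, "train_step", time.time() + 60.0))  # acquire: expires in 60s
    q.put((pid, "train_step", -1.0))                # release

if __name__ == "__main__":
    queue = mp.Queue()
    p = mp.Process(target=client, args=(queue,))
    p.start()
    p.join()
    while not queue.empty():
        pid, scope, expiration = queue.get()
        action = "release" if expiration < 0 else "acquire"
        print(action, pid, scope)
```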
@ -49,7 +49,7 @@ class MultiprocessingRequestQueue(RequestQueue):
A ``RequestQueue`` backed by python ``multiprocessing.Queue``
"""
def __init__(self, mp_queue: mp.Queue):
def __init__(self, mp_queue: mp.Queue) -> None:
super().__init__()
self._mp_queue = mp_queue
@ -86,7 +86,7 @@ class LocalTimerServer(TimerServer):
def __init__(
self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
):
) -> None:
super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
self._timers: dict[tuple[Any, str], TimerRequest] = {}

View File

@ -36,7 +36,7 @@ class CyclingIterator(Iterator[_T]):
n: int,
generator_fn: Callable[[int], Iterator[_T]],
start_epoch: int = 0,
):
) -> None:
self._n = n
self._epoch = start_epoch
self._generator_fn = generator_fn
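
A simplified reimplementation of the cycling behavior, assuming the intent is to rebuild the inner iterator from `generator_fn(epoch)` whenever it is exhausted, for up to `n` epochs:

```python
from collections.abc import Iterator
from typing import Callable, TypeVar

T = TypeVar("T")

# Simplified sketch, not the actual CyclingIterator: when the current epoch's
# iterator is exhausted, advance the epoch and rebuild it from generator_fn.
class SimpleCyclingIterator(Iterator[T]):
    def __init__(self, n: int, generator_fn: Callable[[int], Iterator[T]],
                 start_epoch: int = 0) -> None:
        self._n = n
        self._epoch = start_epoch
        self._generator_fn = generator_fn
        self._iter = generator_fn(start_epoch)

    def __next__(self) -> T:
        while True:
            try:
                return next(self._iter)
            except StopIteration:
                self._epoch += 1
                if self._epoch >= self._n:
                    raise
                self._iter = self._generator_fn(self._epoch)

print(list(SimpleCyclingIterator(2, lambda epoch: iter(range(3)))))  # 0 1 2 0 1 2
```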

View File

@ -47,7 +47,7 @@ class ElasticDistributedSampler(DistributedSampler[T]):
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
start_index: int = 0,
):
) -> None:
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
if not isinstance(dataset, Sized):
raise TypeError("Dataset must be an instance of collections.abc.Sized")

View File

@ -115,7 +115,7 @@ def create_c10d_store(
raise
def _check_full_rank(store, world_size, timeout):
def _check_full_rank(store, world_size, timeout) -> None:
try:
barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
except RuntimeError as e:

View File

@ -3,7 +3,9 @@ import torch
from torch.distributed._tools import MemoryTracker
def run_one_model(net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"):
def run_one_model(
net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"
) -> None:
net.to(device)
input = input.to(device)

View File

@ -61,7 +61,7 @@ class _FSDPDeviceHandle:
semantics to be integrated with FSDP.
"""
def __init__(self, device: torch.device, backend: Any = None):
def __init__(self, device: torch.device, backend: Any = None) -> None:
if backend is None:
try:
self.__backend = getattr(torch, device.type)
@ -187,7 +187,7 @@ class HandleTrainingState(Enum):
SUMMON_FULL_PARAMS = auto()
def _is_composable(state: _FSDPState):
def _is_composable(state: _FSDPState) -> bool:
# TODO: This is a temporary hack for differentiate between code paths.
return not isinstance(state, nn.Module)
@ -305,7 +305,7 @@ def _get_param_to_fqns(
includes the FQNs across all encounters. (Default: ``True``)
"""
def module_fn(module, prefix, tree_level, param_to_fqns):
def module_fn(module, prefix, tree_level, param_to_fqns) -> None:
for param_name, param in _named_parameters_with_duplicates(
module, recurse=False
):
@ -400,7 +400,9 @@ def _apply_to_modules(
to remove the prefix.
"""
def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
def f(
module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs
) -> None:
# Call the module function before recursing over children (pre-order)
module_fn(module, prefix, tree_level, *args, **kwargs)
for submodule_name, submodule in module.named_children():
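
A hedged sketch of the pre-order traversal this hunk relies on: visit the module first, then recurse into its children while extending the FQN prefix (simplified; the real `_apply_to_modules` also threads extra arguments through):

```python
import torch.nn as nn

# Visit the module's own parameters before recursing (pre-order), building
# fully qualified names from the accumulated prefix.
def visit(module: nn.Module, prefix: str = "", tree_level: int = 0) -> None:
    for name, _ in module.named_parameters(recurse=False):
        print(f"{'  ' * tree_level}{prefix}{name}")
    for child_name, child in module.named_children():
        visit(child, prefix=f"{prefix}{child_name}.", tree_level=tree_level + 1)

visit(nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)))
```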

View File

@ -106,7 +106,7 @@ def _get_sharded_module_tree_with_module_name_to_fqns(
def module_fn(
module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
):
) -> None:
num_spaces = tree_level * 4
trimed_prefix = (
prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix

View File

@ -352,7 +352,7 @@ class _ExecOrderData:
fqns.append(self.param_to_fqn[flat_param])
return fqns
def next_iter(self):
def next_iter(self) -> None:
"""
Advances the internal data structures per iteration. This should be
called in the post-backward callback since that marks the true end of

View File

@ -192,7 +192,7 @@ class FlatParamShardMetadata(NamedTuple):
class _FlatParameterMeta(_ParameterMeta):
# Make `isinstance(t, FlatParameter)` return True for custom tensor
# instances that have the _is_flat_param flag for BC
def __instancecheck__(self, instance):
def __instancecheck__(self, instance) -> bool:
# NB: do NOT test the super implementation
return isinstance(instance, torch.Tensor) and getattr(
instance, "_is_flat_param", False
@ -525,7 +525,7 @@ class FlatParamHandle:
use_orig_params: bool,
*,
fsdp_extension: Optional[FSDPExtensions] = None,
):
) -> None:
super().__init__()
params = list(params)
if len(params) == 0:
@ -621,10 +621,10 @@ class FlatParamHandle:
)
self._use_unsharded_views(as_params=False)
def __repr__(self):
def __repr__(self) -> str:
return f"FlatParamHandle(flat_param.fqns={self.flat_param._fqns})"
def _init_setattr_fns(self):
def _init_setattr_fns(self) -> None:
use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
@ -635,7 +635,7 @@ class FlatParamHandle:
self._setattr_tensor = _safe_setattr_tensor_or_param
self._setattr_param = _safe_setattr_tensor_or_param
def _init_get_unflat_views_fn(self, align_addresses: bool):
def _init_get_unflat_views_fn(self, align_addresses: bool) -> None:
self._get_unflat_views = (
self._get_unflat_views_aligned
if align_addresses
@ -945,7 +945,7 @@ class FlatParamHandle:
# SHARD INITIALIZATION & METADATA #
###################################
@torch.no_grad()
def shard(self):
def shard(self) -> None:
"""
Shard the handle's ``FlatParameter``.
@ -1342,7 +1342,7 @@ class FlatParamHandle:
self._check_on_compute_device(self.flat_param)
return ret
def _use_low_precision_shard(self):
def _use_low_precision_shard(self) -> None:
"""Allocate on the compute device and switch to using the low precision sharded flat parameter."""
self._check_low_precision_shard()
flat_param = self.flat_param
@ -1359,7 +1359,7 @@ class FlatParamHandle:
# Invariant: `_mp_shard` is always on the compute device.
flat_param.data = flat_param._mp_shard # type: ignore[attr-defined]
def unshard(self):
def unshard(self) -> None:
"""
Run the unshard logic.
@ -1523,7 +1523,7 @@ class FlatParamHandle:
elif in_forward:
self._use_unsharded_views(as_params=False)
def post_unshard(self):
def post_unshard(self) -> None:
"""
Run the post-unshard logic.
@ -1533,7 +1533,7 @@ class FlatParamHandle:
self._free_low_precision_sharded_param()
self._check_on_compute_device(self.flat_param)
def _free_low_precision_sharded_param(self):
def _free_low_precision_sharded_param(self) -> None:
"""Frees the low precision sharded flat parameter."""
self._check_low_precision_shard()
# `_mp_shard` is allocated in the pre-unshard stream, consumed in the
@ -1550,7 +1550,7 @@ class FlatParamHandle:
_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
@torch.no_grad()
def unshard_grad(self):
def unshard_grad(self) -> None:
"""
Unshard the handle's ``FlatParameter``'s gradient.
@ -1608,7 +1608,7 @@ class FlatParamHandle:
)
self._use_unsharded_grad_views()
def reshard_grad(self):
def reshard_grad(self) -> None:
if self._use_orig_params:
self._use_sharded_grad_views()
if not self.uses_sharded_strategy:
@ -1616,7 +1616,7 @@ class FlatParamHandle:
self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined]
delattr(self.flat_param, "_saved_grad_shard")
def prepare_gradient_for_backward(self):
def prepare_gradient_for_backward(self) -> None:
"""
Prepare the gradient for the backward computation.
@ -1681,10 +1681,10 @@ class FlatParamHandle:
)
flat_param.grad = None
def prepare_gradient_for_optim(self):
def prepare_gradient_for_optim(self) -> None:
"""Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute."""
def cast_grad_to_param_dtype_if_needed(flat_param):
def cast_grad_to_param_dtype_if_needed(flat_param) -> None:
# TODO (rohan-varma): test for full precision with keep_low_precision_grads
if not self._force_full_precision and self._keep_low_precision_grads:
_p_assert(flat_param.grad is not None, "Unexpected None grad!")
@ -1768,7 +1768,7 @@ class FlatParamHandle:
)
self._use_unsharded_flat_param(padded_unsharded_flat_param)
def reshard(self, free_unsharded_flat_param: bool):
def reshard(self, free_unsharded_flat_param: bool) -> None:
"""
Run the reshard logic.
@ -1786,7 +1786,7 @@ class FlatParamHandle:
if free_unsharded_flat_param:
self._free_unsharded_flat_param()
def post_reshard(self):
def post_reshard(self) -> None:
"""
Run the post-reshard logic.
@ -1807,7 +1807,7 @@ class FlatParamHandle:
):
self._free_low_precision_sharded_param()
def _free_unsharded_flat_param(self):
def _free_unsharded_flat_param(self) -> None:
"""
Free the padded unsharded flat parameter. We allow this
function to be called even when storage is not allocated
@ -2451,7 +2451,7 @@ class FlatParamHandle:
raise AssertionError("Expected _is_grad_none_mask to be not None")
self.flat_param._is_grad_none_mask[tensor_index] = True
def _reset_flat_param_grad_info_if_needed(self):
def _reset_flat_param_grad_info_if_needed(self) -> None:
"""
Reset ``flat_param.grad`` if needed.
@ -2480,7 +2480,7 @@ class FlatParamHandle:
# must require gradient
flat_param.requires_grad = requires_grad
def _deregister_orig_params(self):
def _deregister_orig_params(self) -> None:
for param_info in self.flat_param._param_infos:
param_name, module, _ = param_info
if hasattr(module, param_name):
@ -2492,7 +2492,7 @@ class FlatParamHandle:
###########
# HELPERS #
###########
def flat_param_to(self, *args, **kwargs):
def flat_param_to(self, *args, **kwargs) -> None:
"""Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
# pyrefly: ignore [not-iterable]
self.flat_param.data = self.flat_param.to(*args, **kwargs)
@ -2628,23 +2628,23 @@ class FlatParamHandle:
#######################
# CHECKS & INVARIANTS #
#######################
def _check_sharded_strategy(self):
def _check_sharded_strategy(self) -> None:
_p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
def _check_on_compute_device(self, tensor: Tensor):
def _check_on_compute_device(self, tensor: Tensor) -> None:
_p_assert(
tensor.device == self.device,
f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
)
def _check_on_cpu(self, tensor: Tensor):
def _check_on_cpu(self, tensor: Tensor) -> None:
_p_assert(
tensor.device == torch.device("cpu"),
f"Expects tensor to be on CPU but got {tensor.device}",
)
@staticmethod
def _check_storage_freed(tensor: Tensor):
def _check_storage_freed(tensor: Tensor) -> None:
# Compile does not resize during trace
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
_p_assert(
@ -2653,10 +2653,10 @@ class FlatParamHandle:
)
@staticmethod
def _check_storage_allocated(tensor: Tensor):
def _check_storage_allocated(tensor: Tensor) -> None:
_p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated")
def _check_low_precision_shard(self):
def _check_low_precision_shard(self) -> None:
_p_assert(
self._uses_param_mixed_precision,
"Not using low precision for parameters",
@ -2671,7 +2671,7 @@ class FlatParamHandle:
f"Expects the low precision shard to be on {self.device} but got {device}",
)
def _check_unsharded(self, tensor: Tensor):
def _check_unsharded(self, tensor: Tensor) -> None:
msg_prefix = "Expects tensor to be unsharded "
_p_assert(tensor is not None, msg_prefix + "but got `None`")
unsharded_size = self.flat_param._unpadded_unsharded_size
@ -2680,7 +2680,7 @@ class FlatParamHandle:
msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
)
def _check_sharded(self, tensor: Tensor):
def _check_sharded(self, tensor: Tensor) -> None:
msg_prefix = "Expects tensor to be sharded "
_p_assert(tensor is not None, msg_prefix + "but got `None`")
sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
@ -2746,7 +2746,7 @@ def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -
def _safe_setattr_tensor_or_param(
module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
):
) -> None:
# Call `delattr()` and `setattr()` to go through `nn.Module` checks
if hasattr(module, param_name):
delattr(module, param_name)
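
A small sketch contrasting the "safe" and "unsafe" setattr paths referenced here: going through `delattr()`/`setattr()` runs `nn.Module`'s registration logic, while writing into `_parameters` directly bypasses it.

```python
import torch
import torch.nn as nn

lin = nn.Linear(2, 2)
new_weight = nn.Parameter(torch.zeros(2, 2))

# "Safe" path: delattr/setattr go through nn.Module.__delattr__/__setattr__,
# so the new Parameter is re-registered in lin._parameters.
delattr(lin, "weight")
setattr(lin, "weight", new_weight)

# "Unsafe" path: writing into _parameters directly skips those checks.
lin._parameters["weight"] = new_weight
```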
@ -2804,19 +2804,19 @@ def _construct_padding_tensor(
# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning
# message is passed in)
@functools.lru_cache(1)
def _warn_skip_writeback_check(log: logging.Logger, warning: str):
def _warn_skip_writeback_check(log: logging.Logger, warning: str) -> None:
logger.warning(warning)
# Use `lru_cache(1)` to only log the warning once
@functools.lru_cache(1)
def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
def _warn_use_fake_all_gather(log: logging.Logger, warning: str) -> None:
logger.warning(warning)
# Use `lru_cache(1)` to only log the warning once
@functools.lru_cache(1)
def _warn_use_fake_reduce(log: logging.Logger, warning: str):
def _warn_use_fake_reduce(log: logging.Logger, warning: str) -> None:
logger.warning(warning)
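
The `lru_cache(1)` decorator turns these helpers into warn-once functions: with a fixed argument pair, only the first call executes the body and later calls hit the cache. A minimal standalone sketch (illustrative names):

```python
import functools
import logging

logger = logging.getLogger(__name__)

# With the same (logger, message) arguments, only the first call runs the
# body; subsequent calls return the cached None without logging again.
@functools.lru_cache(1)
def _warn_once(log: logging.Logger, warning: str) -> None:
    log.warning(warning)

for _ in range(3):
    _warn_once(logger, "skipping writeback check")  # logs exactly once
```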

View File

@ -60,7 +60,7 @@ class DefaultAllocMixin:
class ProcessGroupAllocMixin:
def __init__(self, group: dist.ProcessGroup, *args: Any, **kwargs: Any):
def __init__(self, group: dist.ProcessGroup, *args: Any, **kwargs: Any) -> None:
self._group = group
super().__init__(*args, **kwargs)

View File

@ -77,7 +77,7 @@ lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()")
@torch.library.impl(lib, "copy_", "HPU")
@torch.library.impl(lib, "copy_", "CPU")
@torch.library.impl(lib, "copy_", "MTIA")
def copy_(tensor, data):
def copy_(tensor, data) -> None:
tensor.copy_(data)
@ -130,7 +130,7 @@ But maybe there are few enough mutations induced by FSDP for this to matter.
@torch.library.impl(lib, "copy_", "Functionalize")
def copy__functionalize(tensor, data):
def copy__functionalize(tensor, data) -> None:
torch._sync(tensor)
torch._sync(data)
tensor_inner = torch._from_functional_tensor(tensor)
@ -185,7 +185,7 @@ class ExtensionsData:
# Save the all-gather input sizes to unflatten the all-gather outputs to ND
all_gather_input_sizes: Sequence[torch.Size] = () # ND
def clear(self):
def clear(self) -> None:
self.all_gather_metadata = None
self.all_gather_input_sizes = ()
@ -229,7 +229,7 @@ class FSDPParam:
shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
mp_policy: MixedPrecisionPolicy,
offload_policy: OffloadPolicy,
):
) -> None:
self._module_info: ParamModuleInfo = module_info
self.mesh_info = mesh_info
self.post_forward_mesh_info = post_forward_mesh_info
@ -262,7 +262,7 @@ class FSDPParam:
param: nn.Parameter,
device: torch.device,
shard_placement_fn: Optional[Callable],
):
) -> None:
if param.device != device and param.device.type != "meta":
raise AssertionError(
f"Expects the parameter to already be moved to device {device} but got {param.device}"
@ -434,7 +434,7 @@ class FSDPParam:
self.sharded_post_forward_size
)
def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy) -> None:
param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
self.orig_dtype = self.sharded_param.dtype
# Clamp `reduce_dtype` to `None` if no casting is required: since
@ -469,7 +469,7 @@ class FSDPParam:
world_size: int,
device: torch.device,
force_recreate: bool = False,
):
) -> None:
if not force_recreate and len(self.all_gather_outputs) > 0:
return # already initialized
self.all_gather_outputs = [
@ -477,7 +477,7 @@ class FSDPParam:
for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
]
def init_unsharded_param(self):
def init_unsharded_param(self) -> None:
"""
[Note: Invariants for torch.compile Traceable FSDP2]
1. Under compile, we always re-populate the content of `self._unsharded_param`
@ -863,7 +863,7 @@ class FSDPParam:
f"Expects to be in one of {states}, not {self.sharded_state}"
)
def reset_sharded_param(self):
def reset_sharded_param(self) -> None:
# For ops like `nn.Module._apply` or `load_state_dict(assign=True)`
# that change the sharded parameter tensor, we may need to re-pad the
# sharded local tensor and re-save the reference.
@ -930,7 +930,7 @@ class FSDPParam:
)
self._sharding_spec = self.sharded_param._spec
def __repr__(self):
def __repr__(self) -> str:
return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})"

View File

@ -58,7 +58,7 @@ reference to avoid holding onto memory after forward.
class FSDPCommContext:
"""This has the communication state shared across FSDP states/parameter groups."""
def lazy_init(self, device: torch.device):
def lazy_init(self, device: torch.device) -> None:
self.device_handle = _get_device_handle(device.type)
# Setting the all-gather/reduce-scatter streams to be higher priority
# can help avoid some issues where their copies in/out are delayed and
@ -133,7 +133,7 @@ class FSDPParamGroup:
shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
mp_policy: MixedPrecisionPolicy,
offload_policy: OffloadPolicy,
):
) -> None:
self.modules = modules # permit ref cycle because 1:1 lifetime
param_module_infos = _get_param_module_infos(params, modules)
@ -248,7 +248,7 @@ class FSDPParamGroup:
)
self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
def lazy_init(self):
def lazy_init(self) -> None:
# Lazy init should be idempotent
# Users may change or register parameters after construction time.
# For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on
@ -300,7 +300,7 @@ class FSDPParamGroup:
)
# Runtime #
def unshard(self, async_op: bool = False):
def unshard(self, async_op: bool = False) -> None:
if self._all_gather_result is not None: # already called, pending wait
return
if self.is_unsharded:
@ -344,7 +344,7 @@ class FSDPParamGroup:
self._all_gather_comm,
)
def wait_for_unshard(self):
def wait_for_unshard(self) -> None:
"""
1. In forward with implicit prefetching, to overlap the current copy-out
with the next all-gather, we save a reference to the current all-gather
@ -419,14 +419,14 @@ class FSDPParamGroup:
self._all_gather_result = None # free unless saved in `all_gather_state`
def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]):
def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]) -> None:
# Calling `unshard` before lazy init means streams are not initialized
if hasattr(self.comm_ctx, "all_gather_copy_in_stream") and event is not None:
self.comm_ctx.all_gather_copy_in_stream.wait_event(event)
if hasattr(self.comm_ctx, "all_gather_stream") and event is not None:
self.comm_ctx.all_gather_stream.wait_event(event)
def reshard(self):
def reshard(self) -> None:
if self._training_state == TrainingState.FORWARD:
if not self._reshard_after_forward:
return
@ -473,7 +473,7 @@ class FSDPParamGroup:
self.comm_ctx.post_forward_order.append(self)
self._post_forward_indices.append(post_forward_index)
def pre_backward(self, default_prefetch: bool, *unused: Any):
def pre_backward(self, default_prefetch: bool, *unused: Any) -> None:
if (
compiled_autograd_enabled()
and self._training_state == TrainingState.PRE_BACKWARD
@ -492,7 +492,7 @@ class FSDPParamGroup:
if default_prefetch and not compiled_autograd_enabled():
self._backward_prefetch()
def post_backward(self, *unused: Any):
def post_backward(self, *unused: Any) -> None:
# This method should be idempotent and safe to call even when this
# FSDP parameter group was not used in backward (should be a no-op)
if not compiled_autograd_enabled():
@ -601,7 +601,7 @@ class FSDPParamGroup:
all_reduce_input, all_reduce_event
)
def finalize_backward(self):
def finalize_backward(self) -> None:
self._wait_for_post_backward()
for fsdp_param in self.fsdp_params:
if fsdp_param.grad_offload_event is not None:
@ -618,7 +618,7 @@ class FSDPParamGroup:
self._all_gather_result = None
self._post_forward_indices.clear()
def _wait_for_post_backward(self):
def _wait_for_post_backward(self) -> None:
if self._post_reduce_event is not None:
self.device_handle.current_stream().wait_event(self._post_reduce_event)
self._post_reduce_event = None
@ -663,19 +663,19 @@ class FSDPParamGroup:
target_fsdp_param_group.unshard(async_op)
# Utilities #
def _to_sharded(self):
def _to_sharded(self) -> None:
if not self.is_sharded:
for fsdp_param in self.fsdp_params:
fsdp_param.to_sharded()
self._sharded_state = ShardedState.SHARDED
def _to_sharded_post_forward(self):
def _to_sharded_post_forward(self) -> None:
if not self.is_sharded_post_forward:
for fsdp_param in self.fsdp_params:
fsdp_param.to_sharded_post_forward()
self._sharded_state = ShardedState.SHARDED_POST_FORWARD
def _to_unsharded(self):
def _to_unsharded(self) -> None:
if not self.is_unsharded:
for fsdp_param in self.fsdp_params:
fsdp_param.to_unsharded()
@ -806,10 +806,10 @@ class FSDPParamGroup:
return f"{label} ({self._module_fqn})"
return label
def __repr__(self):
def __repr__(self) -> str:
return f"FSDPParamGroup(fqn={self._module_fqn})"
def _validate_no_meta_params(self):
def _validate_no_meta_params(self) -> None:
param_names_on_meta = [
fsdp_param._param_fqn
for fsdp_param in self.fsdp_params
@ -823,7 +823,7 @@ class FSDPParamGroup:
"call module.reset_parameters() on each module to initialize values."
)
def _validate_cpu_offload_params(self):
def _validate_cpu_offload_params(self) -> None:
if not isinstance(self.offload_policy, CPUOffloadPolicy):
return
fsdp_params_not_on_cpu = [
@ -873,7 +873,7 @@ def _get_param_module_infos(
class RegisterPostBackwardFunction(torch.autograd.Function):
@staticmethod
def _assert_not_tracing_fsdp():
def _assert_not_tracing_fsdp() -> None:
if compiled_autograd_enabled():
# TODO: Find a way to print the offending FSDP2 module.
msg = """\

View File

@ -347,7 +347,7 @@ class FSDPState(_State):
t.register_hook(self._pre_backward)
return output
def _register_root_post_backward_final_callback(self):
def _register_root_post_backward_final_callback(self) -> None:
if self._state_ctx.post_backward_final_callback_queued:
return
self._state_ctx.post_backward_final_callback_queued = True

View File

@ -632,7 +632,7 @@ def _init_param_handle_from_params(
state: _FSDPState,
params: list[nn.Parameter],
fully_sharded_module: nn.Module,
):
) -> None:
if len(params) == 0:
return
handle = FlatParamHandle(
@ -905,7 +905,7 @@ def _materialize_meta_module(
device_from_device_id: Optional[torch.device],
ignored_modules: set[nn.Module],
device_handle: _FSDPDeviceHandle,
):
) -> None:
# Run default meta device initialization
materialization_device = device_from_device_id or torch.device(
device_handle.current_device()
@ -1046,7 +1046,7 @@ def _move_states_to_device(
_warn_cpu_init()
def _warn_cpu_init():
def _warn_cpu_init() -> None:
warnings.warn(
"The passed-in `module` is on CPU and will thus have FSDP's sharding "
"initialization run on CPU, which may be slower than on GPU. We "

View File

@ -1051,7 +1051,7 @@ def _get_flat_param_to_fqn(model: torch.nn.Module) -> dict[FlatParameter, str]:
"""
def module_fn(module, prefix, tree_level, flat_param_to_fqn):
def module_fn(module, prefix, tree_level, flat_param_to_fqn) -> None:
for param_name, param in _named_parameters_with_duplicates(
module, recurse=False
):
@ -2077,7 +2077,7 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]:
parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters.
"""
def module_fn(module, prefix, tree_level, fqn_to_param_info):
def module_fn(module, prefix, tree_level, fqn_to_param_info) -> None:
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
if fsdp_state is None:
return

View File

@ -143,7 +143,7 @@ def _lazy_init(
return state
def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module) -> None:
"""
Checks that all ``FlatParameter``s in ``module`` 's tree managed by
``state`` are on the expected device for *lazy initialization*.
@ -311,7 +311,7 @@ def _reshard(
state: _FSDPState,
handle: FlatParamHandle,
free_unsharded_flat_param: bool,
):
) -> None:
"""
Reshards the handle. ``free_unsharded_flat_param`` indicates whether to
free the handle's padded unsharded flat parameter.
@ -703,7 +703,7 @@ def _post_backward_hook(
handle: FlatParamHandle,
flat_param,
*unused: Any,
):
) -> None:
"""
Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
@ -955,7 +955,7 @@ def _post_reduce_grad_callback(
handle: FlatParamHandle,
# Additional arguments needed for the callback logic
grad_to_offload: torch.Tensor,
):
) -> None:
"""
This callback captures any logic to run after the gradient reduction
finishes. Currently, this offloads the gradient to CPU if CPU offloading is
@ -970,7 +970,7 @@ def _offload_grad(
state: _FSDPState,
handle: FlatParamHandle,
grad_to_offload: torch.Tensor,
):
) -> None:
if not handle._offload_params:
return
# Offload the gradient to CPU to ensure parameters and gradients are on the
@ -992,7 +992,7 @@ def _offload_grad(
@no_type_check
def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
def _post_backward_use_sharded_grad_views(handle: FlatParamHandle) -> None:
if not handle._use_orig_params:
return
# Since the handle's `FlatParameter` completed its gradient computation, we
@ -1031,7 +1031,7 @@ def _cast_grad_to_param_dtype(
state: _FSDPState,
sharded_grad: torch.Tensor,
param: FlatParameter,
):
) -> None:
"""
Casts ``sharded_grad`` back to the full parameter dtype so that the
optimizer step runs with that dtype. This performs an actual cast if
@ -1084,7 +1084,7 @@ def _low_precision_hook_enabled(state: _FSDPState) -> bool:
def _post_backward_final_callback(
state: _FSDPState,
module: nn.Module,
):
) -> None:
"""
This waits for the post-backward to finish and performs some final cleanup.
This runs at the end of the entire backward pass and should only be called
@ -1346,7 +1346,7 @@ def _register_post_forward_hook(
def _register_root_pre_forward_hook(
state: _FSDPState,
module: nn.Module,
):
) -> None:
"""
Registers root pre-forward hook on ``module``, which should be the local
FSDP root.
@ -1544,7 +1544,7 @@ def _wait_for_computation_stream(
computation_stream: torch.Stream,
unshard_stream: torch.Stream,
pre_unshard_stream: torch.Stream,
):
) -> None:
"""
Has the unshard and pre-unshard streams wait for the computation stream.
For example, this should be called in the FSDP root's pre-forward to
@ -1562,7 +1562,7 @@ def _wait_for_computation_stream(
def _reset_flat_param_grad_info_if_needed(
handles: list[FlatParamHandle],
):
) -> None:
"""
Clears the original parameters' gradients if needed. This method's CPU
overhead is minimal, so we may call it throughout FSDP methods, which serve

View File

@ -18,7 +18,7 @@ from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
def _get_remote_device_str(rank, device_type, num_devices_per_node):
def _get_remote_device_str(rank, device_type, num_devices_per_node) -> str:
if device_type.lower() == "cpu":
return f"rank:{rank}/{device_type}"
elif device_type.lower() == "hpu":

View File

@ -542,7 +542,7 @@ def _sharded_post_state_dict_hook(
with an unflattened, sharded parameter (a ShardedTensor).
"""
def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str):
def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str) -> None:
param = state_dict[fqn]
if not fsdp_state._state_dict_config._use_dtensor:
sharded_tensor = _ext_chunk_tensor(
@ -895,7 +895,7 @@ def _post_load_state_dict_hook(
SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")
def _register_all_state_dict_hooks(state: _FSDPState):
def _register_all_state_dict_hooks(state: _FSDPState) -> None:
"""
Registers pre-save, post-save, pre-load, and post-load state dict hooks.
"""

View File

@ -35,7 +35,7 @@ FLAT_PARAM = "_flat_param"
def _writeback_to_local_shard(
handle: FlatParamHandle,
writeback_grad: bool,
):
) -> None:
"""
For the handle, writes back this rank's shard of the unsharded
flattened parameter to the sharded flattened parameter. If

View File

@ -30,7 +30,7 @@ def _auto_wrap(
ignored_params: set[nn.Parameter],
root_kwargs: dict[str, Any],
fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
):
) -> None:
"""
Auto wraps modules in ``root_module`` 's tree according to ``policy``
following a post-order traversal.
@ -102,7 +102,7 @@ def _auto_wrap(
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
def _check_nested_wrapping(root_module: nn.Module):
def _check_nested_wrapping(root_module: nn.Module) -> None:
for module_name, module in root_module.named_modules():
if _get_module_fsdp_state(module) is not None:
raise ValueError(
@ -113,7 +113,7 @@ def _check_nested_wrapping(root_module: nn.Module):
def _warn_on_overridden_mixed_precision(
overridden_module_classes: set[type[nn.Module]],
):
) -> None:
if len(overridden_module_classes) == 0:
return
warnings.warn(
@ -130,7 +130,7 @@ def _validate_frozen_params(
modules_to_wrap: set[nn.Module],
ignored_params: set[nn.Parameter],
use_orig_params: bool,
):
) -> None:
"""
This checks that, given ``modules_to_wrap``, each module would manage
parameters that are uniformly frozen or non-frozen. This uniformity

Some files were not shown because too many files have changed in this diff.