diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
index a1d1ffd2fc87..d9cc6d12785c 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -1,7 +1,21 @@
 # mypy: allow-untyped-defs
+import sys
 from enum import Enum
 from functools import partial
 
+
+# To suppress FutureWarning from partial since 3.13
+if sys.version_info >= (3, 13):
+    from enum import member
+
+    def _enum_member(x):
+        return member(x)
+else:
+
+    def _enum_member(x):
+        return x
+
+
 import torch.distributed as dist
 
 from . import (
@@ -51,45 +65,61 @@ class DDPCommHookType(Enum):
     ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``.
     """
 
-    ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
-    FP16_COMPRESS = partial(
-        _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook
+    ALLREDUCE = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
     )
-    BF16_COMPRESS = partial(
-        _ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook
+    FP16_COMPRESS = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook)
     )
-    QUANTIZE_PER_TENSOR = partial(
-        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
+    BF16_COMPRESS = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook)
     )
-    QUANTIZE_PER_CHANNEL = partial(
-        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
+    QUANTIZE_PER_TENSOR = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
+        )
     )
-    POWER_SGD = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.powerSGD_hook,
-        matrix_approximation_rank=1,
+    QUANTIZE_PER_CHANNEL = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
+        )
+    )
+    POWER_SGD = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.powerSGD_hook,
+            matrix_approximation_rank=1,
+        )
     )
     # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version,
     # but it runs slower and consumes more memory.
-    POWER_SGD_RANK2 = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.powerSGD_hook,
-        matrix_approximation_rank=2,
+    POWER_SGD_RANK2 = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.powerSGD_hook,
+            matrix_approximation_rank=2,
+        )
     )
     # Batching can lead to a faster training at the cost of accuracy.
-    BATCHED_POWER_SGD = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.batched_powerSGD_hook,
-        matrix_approximation_rank=1,
+    BATCHED_POWER_SGD = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.batched_powerSGD_hook,
+            matrix_approximation_rank=1,
+        )
     )
-    BATCHED_POWER_SGD_RANK2 = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.batched_powerSGD_hook,
-        matrix_approximation_rank=2,
+    BATCHED_POWER_SGD_RANK2 = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.batched_powerSGD_hook,
+            matrix_approximation_rank=2,
+        )
     )
-    NOOP = partial(
-        _ddp_comm_hook_wrapper,
-        comm_hook=debugging.noop_hook,
+    NOOP = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper,
+            comm_hook=debugging.noop_hook,
+        )
     )
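
For context, here is a minimal standalone sketch (not part of this patch) of the behavior being suppressed. On CPython 3.13+, assigning a bare functools.partial in an Enum body emits a FutureWarning along the lines of "functools.partial will be a method descriptor in future Python versions; wrap it in enum.member() if you want to preserve the old behavior"; wrapping the value, as this diff does, keeps the partial as an ordinary member value on every supported Python. The _scale and Hook names below are made up for illustration:

    # Minimal repro/illustration (hypothetical names; not from the PyTorch tree).
    import sys
    from enum import Enum
    from functools import partial


    def _scale(tensor, factor):
        return tensor * factor


    if sys.version_info >= (3, 13):
        # enum.member exists since 3.11; the guard mirrors the patch above.
        from enum import member as _enum_member
    else:

        def _enum_member(x):
            # Identity on older interpreters, which raise no FutureWarning.
            return x


    class Hook(Enum):
        # A bare `partial(...)` here warns on 3.13; the wrapper keeps the
        # partial as the member's value on all versions.
        DOUBLE = _enum_member(partial(_scale, factor=2))


    print(Hook.DOUBLE.value(21))  # -> 42, same call pattern as DDPCommHookType

Gating on 3.13 rather than 3.11 (where enum.member was introduced) is presumably so that interpreters which never emit the warning, including those predating enum.member, run the same code path as before.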