Suppress FutureWarnings in torch.distributed.algorithms.ddp_comm_hooks (#163939)

Fixes #163938

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163939
Approved by: https://github.com/cyyever, https://github.com/kwen2501
Author: Xuanteng Huang
Date: 2025-10-01 07:51:08 +00:00
Committed by: PyTorch MergeBot
Parent: 590224f83c
Commit: 12d4cb0122

torch/distributed/algorithms/ddp_comm_hooks/__init__.py

@@ -1,7 +1,21 @@
 # mypy: allow-untyped-defs
+import sys
 from enum import Enum
 from functools import partial
 
+# To suppress FutureWarning from partial since 3.13
+if sys.version_info >= (3, 13):
+    from enum import member
+
+    def _enum_member(x):
+        return member(x)
+
+else:
+
+    def _enum_member(x):
+        return x
+
+
 import torch.distributed as dist
 
 from . import (
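
Background for the guard above: on Python 3.13+, the enum machinery emits a FutureWarning for a bare functools.partial in an Enum body (partial is slated to become a method-like descriptor there), and wrapping the value in enum.member(), available since Python 3.11, keeps it an ordinary member and silences the warning. Below is a minimal, self-contained sketch of the same pattern; the greet function and Greeting enum are illustrative only, not part of the commit.

import sys
from enum import Enum
from functools import partial

# Same guard as the commit: only 3.13+ warns about bare partials in Enum
# bodies; enum.member marks the callable explicitly as a member value.
if sys.version_info >= (3, 13):
    from enum import member

    def _enum_member(x):
        return member(x)

else:

    def _enum_member(x):
        return x


def greet(greeting, name):
    return f"{greeting}, {name}!"


class Greeting(Enum):
    # Without _enum_member, this assignment warns on Python 3.13+.
    HELLO = _enum_member(partial(greet, "Hello"))


# .value recovers the partial, so it is callable like a plain function.
print(Greeting.HELLO.value("world"))  # Hello, world!
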
@@ -51,45 +65,61 @@ class DDPCommHookType(Enum):
     ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``.
     """
 
-    ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
-    FP16_COMPRESS = partial(
-        _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook
+    ALLREDUCE = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
     )
-    BF16_COMPRESS = partial(
-        _ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook
+    FP16_COMPRESS = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook)
     )
-    QUANTIZE_PER_TENSOR = partial(
-        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
+    BF16_COMPRESS = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook)
     )
-    QUANTIZE_PER_CHANNEL = partial(
-        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
+    QUANTIZE_PER_TENSOR = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
+        )
     )
-    POWER_SGD = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.powerSGD_hook,
-        matrix_approximation_rank=1,
+    QUANTIZE_PER_CHANNEL = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
+        )
     )
+    POWER_SGD = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.powerSGD_hook,
+            matrix_approximation_rank=1,
+        )
+    )
     # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version,
     # but it runs slower and consumes more memory.
-    POWER_SGD_RANK2 = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.powerSGD_hook,
-        matrix_approximation_rank=2,
+    POWER_SGD_RANK2 = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.powerSGD_hook,
+            matrix_approximation_rank=2,
+        )
     )
     # Batching can lead to a faster training at the cost of accuracy.
-    BATCHED_POWER_SGD = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.batched_powerSGD_hook,
-        matrix_approximation_rank=1,
+    BATCHED_POWER_SGD = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.batched_powerSGD_hook,
+            matrix_approximation_rank=1,
+        )
     )
-    BATCHED_POWER_SGD_RANK2 = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.batched_powerSGD_hook,
-        matrix_approximation_rank=2,
+    BATCHED_POWER_SGD_RANK2 = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.batched_powerSGD_hook,
+            matrix_approximation_rank=2,
+        )
     )
-    NOOP = partial(
-        _ddp_comm_hook_wrapper,
-        comm_hook=debugging.noop_hook,
+    NOOP = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper,
+            comm_hook=debugging.noop_hook,
+        )
     )
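
The change is declaration-only: each member's .value is still the original partial, so call sites keep the form shown in the class docstring. A hedged usage sketch, assuming an already-initialized process group and a DDP-wrapped model (register_allreduce_hook is an illustrative helper, not from the commit):

import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import DDPCommHookType
from torch.nn.parallel import DistributedDataParallel


def register_allreduce_hook(
    model: DistributedDataParallel, process_group: dist.ProcessGroup
) -> None:
    # Invoking .value runs the wrapper, which registers the allreduce
    # comm hook on the model; this behavior is unchanged by the fix.
    DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)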