Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Suppress FutureWarnings in torch.distributed.algorithms.ddp_comm_hooks (#163939)
Fixes #163938

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163939
Approved by: https://github.com/cyyever, https://github.com/kwen2501
Committed by PyTorch MergeBot · parent 590224f83c · commit 12d4cb0122
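For context: starting with Python 3.13, a bare functools.partial used directly as an Enum member value emits a FutureWarning suggesting an explicit enum.member() wrapper. A minimal repro sketch, outside PyTorch (the _greet/Greeting names are illustrative only; enum.member requires Python 3.11+):

from enum import Enum, member
from functools import partial


def _greet(prefix, name):
    return f"{prefix}, {name}!"


class Greeting(Enum):
    # On Python 3.13+, defining this member emits a FutureWarning that
    # asks for an explicit enum.member() wrapper.
    LEGACY = partial(_greet, "Hi")
    # The warning-free spelling; the partial stays the member's value.
    HELLO = member(partial(_greet, "Hello"))


print(Greeting.HELLO.value("world"))  # -> Hello, world!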
torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -1,7 +1,21 @@
 # mypy: allow-untyped-defs
+import sys
 from enum import Enum
 from functools import partial
 
+
+# To suppress FutureWarning from partial since 3.13
+if sys.version_info >= (3, 13):
+    from enum import member
+
+    def _enum_member(x):
+        return member(x)
+else:
+
+    def _enum_member(x):
+        return x
+
+
 import torch.distributed as dist
 
 from . import (
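The shim is easy to sanity-check outside PyTorch: whichever branch is taken, the resulting member's .value is still the callable partial. A small sketch (Hook and _fmt are hypothetical names standing in for the real comm-hook machinery):

import sys
from enum import Enum
from functools import partial

# Same version-gated shim as in the hunk above.
if sys.version_info >= (3, 13):
    from enum import member

    def _enum_member(x):
        return member(x)
else:

    def _enum_member(x):
        return x


def _fmt(tag, msg):
    # Hypothetical stand-in for the real comm-hook wrapper.
    return f"[{tag}] {msg}"


class Hook(Enum):
    # No FutureWarning on 3.13+; identical behavior on older interpreters.
    INFO = _enum_member(partial(_fmt, "info"))


assert Hook.INFO.value("hello") == "[info] hello"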
@@ -51,45 +65,61 @@ class DDPCommHookType(Enum):
     ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``.
     """
 
-    ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
-    FP16_COMPRESS = partial(
-        _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook
+    ALLREDUCE = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
     )
-    BF16_COMPRESS = partial(
-        _ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook
+    FP16_COMPRESS = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook)
     )
-    QUANTIZE_PER_TENSOR = partial(
-        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
+    BF16_COMPRESS = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook)
     )
-    QUANTIZE_PER_CHANNEL = partial(
-        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
+    QUANTIZE_PER_TENSOR = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
+        )
     )
-    POWER_SGD = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.powerSGD_hook,
-        matrix_approximation_rank=1,
+    QUANTIZE_PER_CHANNEL = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
+        )
+    )
+    POWER_SGD = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.powerSGD_hook,
+            matrix_approximation_rank=1,
+        )
     )
     # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version,
     # but it runs slower and consumes more memory.
-    POWER_SGD_RANK2 = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.powerSGD_hook,
-        matrix_approximation_rank=2,
+    POWER_SGD_RANK2 = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.powerSGD_hook,
+            matrix_approximation_rank=2,
+        )
     )
     # Batching can lead to a faster training at the cost of accuracy.
-    BATCHED_POWER_SGD = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.batched_powerSGD_hook,
-        matrix_approximation_rank=1,
+    BATCHED_POWER_SGD = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.batched_powerSGD_hook,
+            matrix_approximation_rank=1,
+        )
     )
-    BATCHED_POWER_SGD_RANK2 = partial(
-        _powerSGD_comm_hook_wrapper,
-        comm_hook=powerSGD.batched_powerSGD_hook,
-        matrix_approximation_rank=2,
+    BATCHED_POWER_SGD_RANK2 = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.batched_powerSGD_hook,
+            matrix_approximation_rank=2,
+        )
     )
-    NOOP = partial(
-        _ddp_comm_hook_wrapper,
-        comm_hook=debugging.noop_hook,
+    NOOP = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper,
+            comm_hook=debugging.noop_hook,
+        )
     )
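Nothing changes for callers: each enum value is still the callable wrapper the class docstring describes. A hedged usage sketch (assumes a process group can be initialized; the gloo backend and the toy model are placeholders):

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import (
    DDPCommHookType,
    register_ddp_comm_hook,
)
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("gloo")  # backend and rendezvous are illustrative
net = DDP(nn.Linear(8, 8))

# Same call pattern as before this PR; on Python 3.13+ importing the
# module simply no longer emits a FutureWarning.
register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, net, state=None)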