[BE][Easy] enable UFMT for torch/nn/parallel (#128596)

Part of #123062


Pull Request resolved: https://github.com/pytorch/pytorch/pull/128596
Approved by: https://github.com/mikaylagawarecki
Author: Xuehai Pan
Date: 2024-06-17 03:58:18 +08:00
Committed by: PyTorch MergeBot
parent bfad0aee44
commit dff6342a0b
13 changed files with 233 additions and 122 deletions


@@ -16,10 +16,15 @@ from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING
 import torch
 import torch.distributed as dist
+from torch._utils import _get_device_index
 from torch.autograd import Function, Variable
 from torch.distributed.algorithms.join import Join, Joinable, JoinHook
 from torch.utils._pytree import tree_flatten, tree_unflatten
+
+from ..modules import Module
+from .scatter_gather import gather, scatter_kwargs
+
 RPC_AVAILABLE = False
 if dist.is_available():
     from torch.distributed.distributed_c10d import (
@@ -35,15 +40,10 @@ if dist.is_available():
         _to_kwargs,
         _verify_param_shape_across_processes,
     )
-if torch.distributed.rpc.is_available():
+if dist.rpc.is_available():
     RPC_AVAILABLE = True
     from torch.distributed.rpc import RRef
-from torch._utils import _get_device_index
-from ..modules import Module
-from .scatter_gather import gather, scatter_kwargs  # noqa: F401
 if TYPE_CHECKING:
     from torch.utils.hooks import RemovableHandle
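
Putting the two hunks together, the reordered import block reads roughly as follows. This is a sketch reconstructed from the hunks above (the file appears to be torch/nn/parallel/distributed.py, judging by the imports), not a verbatim copy of the file: the names imported inside the dist.is_available() guard are elided between the hunks, and the relative imports only resolve from inside the torch/nn/parallel package.

from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING

import torch
import torch.distributed as dist
from torch._utils import _get_device_index
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.utils._pytree import tree_flatten, tree_unflatten

# After formatting, the relative (intra-package) imports sit in their own
# block after the absolute torch imports, which is why these two lines moved
# up from the tail of the module in the hunks above.
from ..modules import Module
from .scatter_gather import gather, scatter_kwargs

RPC_AVAILABLE = False
# The dist.is_available() / dist.rpc.is_available() guarded imports and the
# TYPE_CHECKING block follow here, unchanged apart from the lines marked in
# the second hunk.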