Add correct __all__ for torch.distributed and torch.cuda submodules (#85702)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85702
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/rohan-varma
anjali411
2022-10-09 16:01:31 +00:00
committed by PyTorch MergeBot
parent d93b1b9c4e
commit e2a4dfa468
22 changed files with 34 additions and 155 deletions
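
For context: a module's __all__ defines which names `from module import *` exports, and PyTorch's public-API tooling also reads it as the module's declared public surface. A minimal sketch of the mechanism (hypothetical module, not part of this diff):

    # mymodule.py
    from typing import Any   # helper import; would leak through `import *` without __all__

    __all__ = ["greet"]      # only this name is public

    def greet(name: str) -> str:
        return "hello " + name

Before this change, helper names such as Any or find_spec leaked out of these submodules and had to be enumerated in the public-API allowlist (the JSON file below); declaring __all__ explicitly is what allows most of the 155 deleted lines.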

test/allowlist_for_publicAPI.json

@@ -141,20 +141,6 @@
   "torch.backends": [
     "contextmanager"
   ],
-  "torch.cpu.amp.autocast_mode": [
-    "Any"
-  ],
-  "torch.cuda": [
-    "Any",
-    "Device",
-    "Dict",
-    "List",
-    "Optional",
-    "Set",
-    "Tuple",
-    "Union",
-    "classproperty"
-  ],
   "torch.cuda.comm": [
     "broadcast",
     "broadcast_coalesced",
@@ -163,12 +149,6 @@
     "scatter",
     "gather"
   ],
-  "torch.cuda.amp.autocast_mode": [
-    "Any"
-  ],
-  "torch.cuda.amp.common": [
-    "find_spec"
-  ],
   "torch.cuda.nccl": [
     "init_rank",
     "is_available",
@@ -292,47 +272,7 @@
   "torch.distributed.optim.utils": [
     "Type"
   ],
-  "torch.distributed.pipeline.sync.checkpoint": [
-    "Checkpoint",
-    "Checkpointing",
-    "Context",
-    "Function",
-    "Recompute",
-    "ThreadLocal",
-    "checkpoint",
-    "enable_checkpointing",
-    "enable_recomputing",
-    "restore_rng_states",
-    "save_rng_states"
-  ],
-  "torch.distributed.pipeline.sync.copy": [
-    "Context",
-    "Copy",
-    "Wait"
-  ],
-  "torch.distributed.pipeline.sync.dependency": [
-    "Fork",
-    "Join",
-    "fork",
-    "join"
-  ],
-  "torch.distributed.pipeline.sync.microbatch": [
-    "Batch",
-    "NoChunk",
-    "check",
-    "gather",
-    "scatter"
-  ],
-  "torch.distributed.pipeline.sync.phony": [
-    "get_phony"
-  ],
-  "torch.distributed.pipeline.sync.pipe": [
-    "BalanceError",
-    "PipeSequential",
-    "Pipeline",
-    "WithDevice"
-  ],
   "torch.distributed.pipeline.sync.pipeline": [
     "Pipeline"
   ],
   "torch.distributed.pipeline.sync.skip.layout": [
@@ -356,25 +296,6 @@
     "current_skip_tracker",
     "use_skip_tracker"
   ],
-  "torch.distributed.pipeline.sync.stream": [
-    "CPUStreamType",
-    "as_cuda",
-    "current_stream",
-    "default_stream",
-    "get_device",
-    "is_cuda",
-    "new_stream",
-    "record_stream",
-    "use_device",
-    "use_stream",
-    "wait_stream"
-  ],
-  "torch.distributed.pipeline.sync.worker": [
-    "Task",
-    "create_workers",
-    "spawn_workers",
-    "worker"
-  ],
   "torch.distributed.remote_device": [
     "Optional",
     "Union"
@@ -395,69 +316,6 @@
     "urlunparse"
   ],
   "torch.distributed.rpc": [
-    "Any",
-    "Dict",
-    "Future",
-    "Generator",
-    "Generic",
-    "GenericWithOneTypeVar",
-    "PyRRef",
-    "RemoteProfilerManager",
-    "RpcAgent",
-    "RpcBackendOptions",
-    "Set",
-    "Store",
-    "TensorPipeAgent",
-    "Tuple",
-    "TypeVar",
-    "WorkerInfo",
-    "enable_gil_profiling",
-    "get_rpc_timeout",
-    "method",
-    "timedelta",
-    "urlparse"
-  ],
-  "torch.distributed.rpc.api": [
-    "Any",
-    "Dict",
-    "Future",
-    "Generic",
-    "GenericWithOneTypeVar",
-    "PyRRef",
-    "PythonUDF",
-    "RPCExecMode",
-    "RemoteProfilerManager",
-    "Set",
-    "TypeVar",
-    "WorkerInfo",
-    "get_rpc_timeout",
-    "method"
-  ],
-  "torch.distributed.rpc.backend_registry": [
-    "Dict",
-    "List",
-    "Set",
-    "Tuple"
-  ],
-  "torch.distributed.rpc.constants": [
-    "timedelta"
-  ],
-  "torch.distributed.rpc.internal": [
-    "Enum"
-  ],
-  "torch.distributed.rpc.options": [
-    "DeviceType",
-    "Dict",
-    "List",
-    "Optional",
-    "Union"
-  ],
-  "torch.distributions.utils": [
-    "Any",
-    "Dict",
-    "Number",
-    "is_tensor_like",
-    "update_wrapper"
-  ],
   ],
   "torch.fft": [
     "Tensor",

torch/cpu/amp/autocast_mode.py

@@ -1,6 +1,8 @@
 import torch
 from typing import Any
+__all__ = ["autocast"]
 class autocast(torch.amp.autocast_mode.autocast):
     r"""
     See :class:`torch.autocast`.
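
A quick way to observe the new export set (a sketch, assuming a build that includes this change):

    ns = {}
    exec("from torch.cpu.amp.autocast_mode import *", ns)
    assert "autocast" in ns   # the one declared public name
    assert "Any" not in ns    # the typing helper no longer leaks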

torch/cuda/__init__.py

@@ -834,7 +834,7 @@ __all__ = [
     'IntStorage', 'IntTensor',
     'LongStorage', 'LongTensor',
     'ShortStorage', 'ShortTensor',
-    'CUDAGraph', 'CudaError', 'DeferredCudaCallError', 'Device', 'Event', 'ExternalStream', 'OutOfMemoryError',
+    'CUDAGraph', 'CudaError', 'DeferredCudaCallError', 'Event', 'ExternalStream', 'OutOfMemoryError',
     'Stream', 'StreamContext', 'amp', 'caching_allocator_alloc', 'caching_allocator_delete', 'can_device_access_peer',
     'check_error', 'cudaStatus', 'cudart', 'current_blas_handle', 'current_device', 'current_stream', 'default_generators',
     'default_stream', 'device', 'device_count', 'device_of', 'empty_cache', 'get_arch_list', 'get_device_capability',
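
Only 'Device' is dropped from torch.cuda's __all__ here; it appears to be an alias used for type annotations rather than intended API, while every remaining entry is a real public attribute. The invariant that makes such a list checkable, sketched rather than quoting the repo's actual test:

    import torch.cuda

    # every name declared public should resolve on the module
    missing = [n for n in torch.cuda.__all__ if not hasattr(torch.cuda, n)]
    assert not missing, f"names in __all__ missing from torch.cuda: {missing}"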

torch/cuda/amp/autocast_mode.py

@@ -9,6 +9,7 @@ except ModuleNotFoundError:
 from torch._six import string_classes
 from typing import Any
+__all__ = ["autocast", "custom_fwd", "custom_bwd"]
 class autocast(torch.amp.autocast_mode.autocast):
     r"""

torch/cuda/amp/common.py

@@ -1,6 +1,7 @@
 import torch
 from importlib.util import find_spec
+__all__ = ["amp_definitely_not_available"]
 def amp_definitely_not_available():
     return not (torch.cuda.is_available() or find_spec('torch_xla'))
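
amp_definitely_not_available is now the module's single exported helper: it answers "can AMP possibly work here?" by checking for a CUDA runtime or a torch_xla installation. A usage sketch (illustrative; mirrors the kind of guard GradScaler-style code performs internally):

    from torch.cuda.amp.common import amp_definitely_not_available

    if amp_definitely_not_available():
        print("AMP unavailable (no CUDA, no torch_xla); staying in full precision")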

torch/distributed/__init__.py

@@ -4,7 +4,6 @@ from enum import Enum
 import torch
 def is_available() -> bool:
     """
     Returns ``True`` if the distributed package is available. Otherwise,

torch/distributed/pipeline/sync/checkpoint.py

@@ -47,7 +47,9 @@ from .dependency import fork, join
 from .microbatch import Batch
 from .phony import get_phony
-__all__ = ["is_checkpointing", "is_recomputing"]
+__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing",
+           "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states",
+           "restore_rng_states", "Checkpoint", "Recompute"]
 Tensors = Sequence[Tensor]

torch/distributed/pipeline/sync/copy.py

@@ -15,7 +15,7 @@ from torch import Tensor
 from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream
-__all__: List[str] = []
+__all__: List[str] = ["Context", "Copy", "Wait"]
 Tensors = Sequence[Tensor]

torch/distributed/pipeline/sync/dependency.py

@@ -12,7 +12,7 @@ from torch import Tensor
 from .phony import get_phony
-__all__: List[str] = []
+__all__: List[str] = ["fork", "Fork", "join", "Join"]
 def fork(input: Tensor) -> Tuple[Tensor, Tensor]:

torch/distributed/pipeline/sync/microbatch.py

@@ -12,7 +12,7 @@ import torch
 from torch import Tensor
 import torch.cuda.comm
-__all__: List[str] = []
+__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"]
 Tensors = Sequence[Tensor]

torch/distributed/pipeline/sync/phony.py

@@ -12,7 +12,7 @@ from torch import Tensor
 from .stream import default_stream, use_stream
-__all__: List[str] = []
+__all__: List[str] = ["get_phony"]
 _phonies: Dict[Tuple[torch.device, bool], Tensor] = {}

torch/distributed/pipeline/sync/pipe.py

@@ -21,7 +21,7 @@ from .skip.layout import inspect_skip_layout
 from .skip.skippable import verify_skippables
 from .stream import AbstractStream, new_stream
-__all__ = ["Pipe"]
+__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"]
 Device = Union[torch.device, int, str]

torch/distributed/pipeline/sync/pipeline.py

@@ -23,7 +23,7 @@ from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
 from .stream import AbstractStream, current_stream, use_device
 from .worker import Task, create_workers
-__all__: List[str] = []
+__all__: List[str] = ["Pipeline"]
 Tensors = Sequence[Tensor]

torch/distributed/pipeline/sync/stream.py

@@ -12,7 +12,9 @@ from typing import Generator, List, Union, cast
 import torch
-__all__: List[str] = []
+__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream",
+                      "use_device", "use_stream", "get_device", "wait_stream", "record_stream",
+                      "is_cuda", "as_cuda"]
 class CPUStreamType:

torch/distributed/pipeline/sync/utils.py

@@ -1,6 +1,8 @@
 from torch import nn
 from typing import List
+__all__ = ["partition_model"]
 def partition_model(
     module: nn.Sequential,
     balance: List[int],

torch/distributed/pipeline/sync/worker.py

@@ -17,7 +17,7 @@ import torch
 from .microbatch import Batch
 from .stream import AbstractStream, use_device, use_stream
-__all__: List[str] = []
+__all__: List[str] = ["Task", "worker", "create_workers", "spawn_workers"]
 ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

torch/distributed/rpc/__init__.py

@@ -9,13 +9,13 @@ from urllib.parse import urlparse
 import torch
 import torch.distributed as dist
 logger = logging.getLogger(__name__)
 _init_counter = 0
 _init_counter_lock = threading.Lock()
+__all__ = ["is_available"]
 def is_available():
     return hasattr(torch._C, "_rpc_init")
@@ -77,6 +77,9 @@ if is_available():
     rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
+    __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"]
+    __all__ = __all__ + api.__all__ + backend_registry.__all__
     def init_rpc(
         name,
         backend=None,
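
Note the two-stage __all__ in this file: the module unconditionally exports only is_available, then extends the list inside the is_available() branch, because the RPC bindings exist only when torch._C was built with _rpc_init. Callers can rely on the same guard (a sketch):

    import torch.distributed.rpc as rpc

    if rpc.is_available():
        # e.g. rpc.init_rpc("worker0", rank=0, world_size=1)
        pass
    else:
        print("this PyTorch build has no RPC support")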

torch/distributed/rpc/backend_registry.py

@@ -11,6 +11,9 @@ from ._utils import _group_membership_management, _update_group_membership
 from . import api
 from . import constants as rpc_constants
+__all__ = ["backend_registered", "register_backend", "construct_rpc_backend_options", "init_backend",
+           "BackendValue", "BackendType"]
 BackendValue = collections.namedtuple(
     "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
 )

torch/distributed/rpc/constants.py

@@ -1,5 +1,5 @@
 from datetime import timedelta
+from typing import List
 from torch._C._distributed_rpc import (
     _DEFAULT_INIT_METHOD,
     _DEFAULT_NUM_WORKER_THREADS,
@@ -20,3 +20,5 @@ DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS
 DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1)
 # Value indicating that timeout is not set for RPC call, and the default should be used.
 UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT
+__all__: List[str] = []

torch/distributed/rpc/internal.py

@@ -11,6 +11,7 @@ import torch
 import torch.distributed as dist
 from torch._C._distributed_rpc import _get_current_rpc_agent
+__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"]
 # Thread local tensor tables to store tensors while pickling torch.Tensor
 # objects

torch/distributed/rpc/options.py

@@ -7,6 +7,7 @@ from . import constants as rpc_contants
 DeviceType = Union[int, str, torch.device]
+__all__ = ["TensorPipeRpcBackendOptions"]
 def _to_device(device: DeviceType) -> torch.device:
     device = torch.device(device)

torch/distributions/utils.py

@@ -7,6 +7,8 @@ from torch.overrides import is_tensor_like
 euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
+__all__ = ["broadcast_all", "logits_to_probs", "clamp_probs", "probs_to_logits", "lazy_property",
+           "tril_matrix_to_vec", "vec_to_tril_matrix"]
 def broadcast_all(*values):
     r"""