Add __all__ to torch.distributed and tensorboard submodules (#80444)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80444
Approved by: https://github.com/rohan-varma
This commit is contained in:
PyTorch MergeBot
2022-06-28 13:22:13 +00:00
parent c8943f831e
commit 14a7cf79c1
7 changed files with 8 additions and 98 deletions

View File

@@ -543,44 +543,6 @@
"torch.distributed.elastic.rendezvous.dynamic_rendezvous": [
"get_method_name"
],
"torch.distributed.elastic.rendezvous.registry": [
"RendezvousHandler",
"RendezvousParameters",
"create_handler"
],
"torch.distributed.elastic.rendezvous.utils": [
"Any",
"Callable",
"Dict",
"Event",
"Optional",
"Thread",
"Tuple",
"Union",
"timedelta"
],
"torch.distributed.elastic.timer.api": [
"Any",
"Dict",
"List",
"Optional",
"Set",
"contextmanager",
"getframeinfo",
"stack"
],
"torch.distributed.elastic.timer.local_timer": [
"Any",
"Dict",
"Empty",
"List",
"RequestQueue",
"Set",
"TimerClient",
"TimerRequest",
"TimerServer",
"Tuple"
],
"torch.distributed.elastic.utils.api": [
"Any",
"List",
@@ -613,34 +575,6 @@
"Union",
"accumulate"
],
"torch.distributed.fsdp.fully_sharded_data_parallel": [
"Any",
"Callable",
"Dict",
"Enum",
"FlatParameter",
"FlattenParamsWrapper",
"Generator",
"Iterable",
"Iterator",
"List",
"Mapping",
"NamedTuple",
"Optional",
"Parameter",
"ProcessGroup",
"Set",
"Shard",
"ShardedTensor",
"Tuple",
"Union",
"Variable",
"auto",
"cast",
"contextmanager",
"dataclass",
"init_from_local_shards"
],
"torch.distributed.fsdp.utils": [
"Any",
"Callable",
@@ -662,25 +596,6 @@
"Type",
"cast"
],
"torch.distributed.launcher.api": [
"Any",
"Callable",
"ChildFailedError",
"Dict",
"List",
"LocalElasticAgent",
"Optional",
"RendezvousParameters",
"SignalException",
"Std",
"Tuple",
"Union",
"WorkerSpec",
"dataclass",
"field",
"get_logger",
"parse_rendezvous_endpoint"
],
"torch.distributed.nn": [
"Function",
"ReduceOp",
@@ -2114,19 +2029,6 @@
"IO",
"Union"
],
"torch.utils.tensorboard.summary": [
"HistogramProto",
"Optional",
"PrCurvePluginData",
"Summary",
"SummaryMetadata",
"TensorProto",
"TensorShapeProto",
"TextPluginData",
"convert_to_HWC",
"make_np",
"range"
],
"torch": [
"BFloat16Storage",
"BFloat16Tensor",

View File

@@ -8,6 +8,7 @@ from .api import RendezvousHandler, RendezvousParameters
from .api import rendezvous_handler_registry as handler_registry
from .dynamic_rendezvous import create_handler
__all__ = ['get_rendezvous_handler']
def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
from . import static_tcp_rendezvous

View File

@@ -14,6 +14,7 @@ from datetime import timedelta
from threading import Event, Thread
from typing import Any, Callable, Dict, Optional, Tuple, Union
__all__ = ['parse_rendezvous_endpoint']
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
"""Extracts key-value pairs from a rendezvous configuration string.

View File

@@ -11,6 +11,7 @@ from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set
__all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires']
class TimerRequest:
"""

View File

@@ -13,6 +13,7 @@ from typing import Any, Dict, List, Set, Tuple
from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer']
class LocalTimerClient(TimerClient):
"""

View File

@@ -20,6 +20,7 @@ from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']
logger = get_logger()

View File

@@ -20,6 +20,9 @@ from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
from ._convert_np import make_np
from ._utils import _prepare_video, convert_to_HWC
__all__ = ['hparams', 'scalar', 'histogram_raw', 'histogram', 'make_histogram', 'image', 'image_boxes', 'draw_boxes',
'make_image', 'video', 'make_video', 'audio', 'custom_scalars', 'text', 'pr_curve_raw', 'pr_curve', 'compute_curve',
'mesh']
logger = logging.getLogger(__name__)