Add __all__ to torch.distributed, futures, fx, nn, package, benchmark submodules (#80520)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80520
Approved by: https://github.com/rohan-varma
This commit is contained in:
anjali411
2022-07-07 18:21:06 +00:00
committed by PyTorch MergeBot
parent 81ca2ff353
commit 4bf076e964
21 changed files with 48 additions and 141 deletions

View File

@ -459,12 +459,6 @@
"torch.distributed.elastic.metrics": [ "torch.distributed.elastic.metrics": [
"Optional" "Optional"
], ],
"torch.distributed.elastic.metrics.api": [
"Dict",
"Optional",
"namedtuple",
"wraps"
],
"torch.distributed.elastic.multiprocessing": [ "torch.distributed.elastic.multiprocessing": [
"Callable", "Callable",
"Dict", "Dict",
@ -508,12 +502,6 @@
"get_logger", "get_logger",
"wraps" "wraps"
], ],
"torch.distributed.elastic.multiprocessing.errors.error_handler": [
"Optional"
],
"torch.distributed.elastic.multiprocessing.errors.handlers": [
"ErrorHandler"
],
"torch.distributed.elastic.multiprocessing.redirects": [ "torch.distributed.elastic.multiprocessing.redirects": [
"contextmanager", "contextmanager",
"partial", "partial",
@ -920,15 +908,7 @@
"svd_lowrank" "svd_lowrank"
], ],
"torch.futures": [ "torch.futures": [
"Callable", "Future"
"Future",
"Generic",
"List",
"Optional",
"Type",
"TypeVar",
"Union",
"cast"
], ],
"torch.fx": [ "torch.fx": [
"ProxyableClassMeta", "ProxyableClassMeta",
@ -1033,25 +1013,6 @@
"Tuple", "Tuple",
"compatibility" "compatibility"
], ],
"torch.fx.node": [
"Any",
"ArgsKwargsPair",
"Argument",
"BaseArgumentTypes",
"Callable",
"Dict",
"List",
"Optional",
"Set",
"Target",
"Tuple",
"Union",
"compatibility",
"immutable_dict",
"immutable_list",
"normalize_function",
"normalize_module"
],
"torch.fx.operator_schemas": [ "torch.fx.operator_schemas": [
"Any", "Any",
"Callable", "Callable",
@ -1072,53 +1033,6 @@
"chain", "chain",
"compatibility" "compatibility"
], ],
"torch.fx.passes.graph_manipulation": [
"Any",
"Argument",
"Dict",
"Graph",
"GraphModule",
"List",
"NamedTuple",
"Node",
"Optional",
"ShapeProp",
"Target",
"Tuple",
"compatibility",
"lift_lowering_attrs_to_nodes",
"map_aggregate",
"map_arg"
],
"torch.fx.passes.param_fetch": [
"Any",
"Callable",
"Dict",
"GraphModule",
"List",
"Tuple",
"Type",
"compatibility"
],
"torch.fx.passes.shape_prop": [
"Any",
"Dict",
"NamedTuple",
"Node",
"Optional",
"Tuple",
"compatibility",
"map_aggregate"
],
"torch.fx.passes.split_module": [
"Any",
"Callable",
"Dict",
"GraphModule",
"List",
"Optional",
"compatibility"
],
"torch.fx.proxy": [ "torch.fx.proxy": [
"assert_fn" "assert_fn"
], ],
@ -1412,11 +1326,6 @@
"torch.nn.intrinsic.modules": [ "torch.nn.intrinsic.modules": [
"_FusedModule" "_FusedModule"
], ],
"torch.nn.intrinsic.qat.modules.conv_fused": [
"Parameter",
"TypeVar",
"fuse_conv_bn_weights"
],
"torch.nn.intrinsic.qat.modules.linear_fused": [ "torch.nn.intrinsic.qat.modules.linear_fused": [
"Parameter", "Parameter",
"fuse_linear_bn_weights" "fuse_linear_bn_weights"
@ -1436,14 +1345,6 @@
"torch.nn.parallel.comm": [ "torch.nn.parallel.comm": [
"List" "List"
], ],
"torch.nn.parallel.data_parallel": [
"Module",
"chain",
"gather",
"parallel_apply",
"replicate",
"scatter_kwargs"
],
"torch.nn.parallel.parallel_apply": [ "torch.nn.parallel.parallel_apply": [
"ExceptionWrapper", "ExceptionWrapper",
"autocast" "autocast"
@ -1452,8 +1353,7 @@
"OrderedDict" "OrderedDict"
], ],
"torch.nn.parallel.scatter_gather": [ "torch.nn.parallel.scatter_gather": [
"Gather", "is_namedtuple"
"Scatter"
], ],
"torch.nn.parameter": [ "torch.nn.parameter": [
"OrderedDict" "OrderedDict"
@ -1514,26 +1414,6 @@
"Iterable", "Iterable",
"Optional" "Optional"
], ],
"torch.nn.utils.parametrizations": [
"Enum",
"Module",
"Optional",
"Tensor",
"auto"
],
"torch.nn.utils.parametrize": [
"Dict",
"Module",
"ModuleDict",
"ModuleList",
"Optional",
"Parameter",
"Sequence",
"Tensor",
"Tuple",
"Union",
"contextmanager"
],
"torch.onnx": [ "torch.onnx": [
"Dict", "Dict",
"OperatorExportTypes", "OperatorExportTypes",
@ -1550,11 +1430,6 @@
"has_torch_function", "has_torch_function",
"push_torch_function_mode" "push_torch_function_mode"
], ],
"torch.package.analyze.find_first_use_of_broken_modules": [
"Dict",
"List",
"PackagingError"
],
"torch.package.analyze.is_from_package": [ "torch.package.analyze.is_from_package": [
"Any", "Any",
"ModuleType", "ModuleType",
@ -1919,10 +1794,6 @@
"rand", "rand",
"randn" "randn"
], ],
"torch.torch_version": [
"Any",
"Iterable"
],
"torch.types": [ "torch.types": [
"Any", "Any",
"Device", "Device",
@ -1937,14 +1808,6 @@
"enable_minidumps", "enable_minidumps",
"enable_minidumps_on_exceptions" "enable_minidumps_on_exceptions"
], ],
"torch.utils.benchmark.utils.common": [
"_make_temp_dir",
"ordered_unique",
"select_unit",
"set_torch_threads",
"trim_sigfig",
"unit_to_english"
],
"torch.utils.benchmark.utils.compare": [ "torch.utils.benchmark.utils.compare": [
"Colorize", "Colorize",
"Table", "Table",

View File

@ -13,6 +13,10 @@ from collections import namedtuple
from functools import wraps from functools import wraps
from typing import Dict, Optional from typing import Dict, Optional
__all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream',
'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms',
'MetricData']
MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])

View File

@ -14,6 +14,7 @@ import traceback
import warnings import warnings
from typing import Optional from typing import Optional
__all__ = ['ErrorHandler']
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View File

@ -10,6 +10,7 @@
from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
__all__ = ['get_error_handler']
def get_error_handler(): def get_error_handler():
return ErrorHandler() return ErrorHandler()

View File

@ -4,6 +4,8 @@ from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union
import torch import torch
__all__ = ['Future', 'collect_all', 'wait_all']
T = TypeVar("T") T = TypeVar("T")
S = TypeVar("S") S = TypeVar("S")

View File

@ -11,6 +11,8 @@ from torch.fx.operator_schemas import normalize_function, normalize_module, Args
if TYPE_CHECKING: if TYPE_CHECKING:
from .graph import Graph from .graph import Graph
__all__ = ['Node', 'map_arg', 'map_aggregate']
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
torch.Tensor, torch.device, torch.memory_format, torch.layout] torch.Tensor, torch.device, torch.memory_format, torch.layout]
base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined]

View File

@ -9,6 +9,7 @@ from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
from itertools import chain from itertools import chain
__all__ = ['FxGraphDrawer']
try: try:
import pydot import pydot
HAS_PYDOT = True HAS_PYDOT = True

View File

@ -11,6 +11,9 @@ from torch.fx.node import (
) )
from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.shape_prop import ShapeProp
__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
'get_size_of_node']
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def replace_target_nodes_with( def replace_target_nodes_with(
fx_module: GraphModule, fx_module: GraphModule,

View File

@ -5,6 +5,7 @@ import torch.nn as nn
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']
# Matching method matches the attribute name of current version to the attribute name of `target_version` # Matching method matches the attribute name of current version to the attribute name of `target_version`
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)

View File

@ -6,6 +6,7 @@ from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
__all__ = ['TensorMetadata', 'ShapeProp']
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple): class TensorMetadata(NamedTuple):

View File

@ -4,6 +4,8 @@ from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
import inspect import inspect
__all__ = ['Partition', 'split_module']
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
class Partition: class Partition:
def __init__(self, name: str): def __init__(self, name: str):

View File

@ -10,6 +10,8 @@ from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from typing import TypeVar from typing import TypeVar
__all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d',
'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats']
_BN_CLASS_MAP = { _BN_CLASS_MAP = {
1: nn.BatchNorm1d, 1: nn.BatchNorm1d,
2: nn.BatchNorm2d, 2: nn.BatchNorm2d,

View File

@ -13,6 +13,8 @@ from torch._utils import (
_get_devices_properties _get_devices_properties
) )
__all__ = ['DataParallel', 'data_parallel']
def _check_balance(device_ids): def _check_balance(device_ids):
imbalance_warn = """ imbalance_warn = """
There is an imbalance between your GPUs. You may want to exclude GPU {} which There is an imbalance between your GPUs. You may want to exclude GPU {} which

View File

@ -1,7 +1,15 @@
import torch import torch
from ._functions import Scatter, Gather from ._functions import Scatter, Gather
import warnings
__all__ = ['scatter', 'scatter_kwargs', 'gather']
def is_namedtuple(obj): def is_namedtuple(obj):
# Check if type was created from collections.namedtuple or a typing.NamedTuple.
warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
return _is_namedtuple(obj)
def _is_namedtuple(obj):
# Check if type was created from collections.namedtuple or a typing.NamedTuple. # Check if type was created from collections.namedtuple or a typing.NamedTuple.
return ( return (
isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")

View File

@ -8,6 +8,7 @@ from .. import functional as F
from typing import Optional from typing import Optional
__all__ = ['orthogonal', 'spectral_norm']
def _is_orthogonal(Q, eps=None): def _is_orthogonal(Q, eps=None):
n, k = Q.size(-2), Q.size(-1) n, k = Q.size(-2), Q.size(-1)

View File

@ -7,6 +7,9 @@ import collections
from contextlib import contextmanager from contextlib import contextmanager
from typing import Union, Optional, Dict, Tuple, Sequence from typing import Union, Optional, Dict, Tuple, Sequence
__all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations',
'type_before_parametrizations', 'transfer_parametrizations_and_params']
_cache_enabled = 0 _cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {} _cache: Dict[Tuple[int, str], Optional[Tensor]] = {}

View File

@ -2,6 +2,8 @@ from typing import Dict, List
from ..package_exporter import PackagingError from ..package_exporter import PackagingError
__all__ = ["find_first_use_of_broken_modules"]
def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]: def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]:
""" """

View File

@ -1,6 +1,8 @@
from typing import Any, Iterable from typing import Any, Iterable
from .version import __version__ as internal_version from .version import __version__ as internal_version
__all__ = ['TorchVersion', 'Version', 'InvalidVersion']
class _LazyImport: class _LazyImport:
"""Wraps around classes lazy imported from packaging.version """Wraps around classes lazy imported from packaging.version
Output of the function v in following snippets are identical: Output of the function v in following snippets are identical:

View File

@ -14,7 +14,7 @@ import uuid
import torch import torch
__all__ = ["TaskSpec", "Measurement", "_make_temp_dir"] __all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"]
_MAX_SIGNIFICANT_FIGURES = 4 _MAX_SIGNIFICANT_FIGURES = 4

View File

@ -7,7 +7,7 @@ from typing import DefaultDict, List, Optional, Tuple
from torch.utils.benchmark.utils import common from torch.utils.benchmark.utils import common
from torch import tensor as _tensor from torch import tensor as _tensor
__all__ = ["Compare"] __all__ = ["Colorize", "Compare"]
BEST = "\033[92m" BEST = "\033[92m"
GOOD = "\033[34m" GOOD = "\033[34m"

View File

@ -46,6 +46,12 @@ HIPIFY_FINAL_RESULT: HipifyFinalResult = {}
to their actual types.""" to their actual types."""
PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"} PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}
__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_caffe2_gpu_file',
'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify']
class InputError(Exception): class InputError(Exception):
# Exception raised for errors in the input. # Exception raised for errors in the input.