mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Add __all__ to torch.distributed, futures, fx, nn, package, benchmark submodules (#80520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80520 Approved by: https://github.com/rohan-varma
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							81ca2ff353
						
					
				
				
					commit
					4bf076e964
				
			| @ -459,12 +459,6 @@ | ||||
|   "torch.distributed.elastic.metrics": [ | ||||
|     "Optional" | ||||
|   ], | ||||
|   "torch.distributed.elastic.metrics.api": [ | ||||
|     "Dict", | ||||
|     "Optional", | ||||
|     "namedtuple", | ||||
|     "wraps" | ||||
|   ], | ||||
|   "torch.distributed.elastic.multiprocessing": [ | ||||
|     "Callable", | ||||
|     "Dict", | ||||
| @ -508,12 +502,6 @@ | ||||
|     "get_logger", | ||||
|     "wraps" | ||||
|   ], | ||||
|   "torch.distributed.elastic.multiprocessing.errors.error_handler": [ | ||||
|     "Optional" | ||||
|   ], | ||||
|   "torch.distributed.elastic.multiprocessing.errors.handlers": [ | ||||
|     "ErrorHandler" | ||||
|   ], | ||||
|   "torch.distributed.elastic.multiprocessing.redirects": [ | ||||
|     "contextmanager", | ||||
|     "partial", | ||||
| @ -920,15 +908,7 @@ | ||||
|     "svd_lowrank" | ||||
|   ], | ||||
|   "torch.futures": [ | ||||
|     "Callable", | ||||
|     "Future", | ||||
|     "Generic", | ||||
|     "List", | ||||
|     "Optional", | ||||
|     "Type", | ||||
|     "TypeVar", | ||||
|     "Union", | ||||
|     "cast" | ||||
|     "Future" | ||||
|   ], | ||||
|   "torch.fx": [ | ||||
|     "ProxyableClassMeta", | ||||
| @ -1033,25 +1013,6 @@ | ||||
|     "Tuple", | ||||
|     "compatibility" | ||||
|   ], | ||||
|   "torch.fx.node": [ | ||||
|     "Any", | ||||
|     "ArgsKwargsPair", | ||||
|     "Argument", | ||||
|     "BaseArgumentTypes", | ||||
|     "Callable", | ||||
|     "Dict", | ||||
|     "List", | ||||
|     "Optional", | ||||
|     "Set", | ||||
|     "Target", | ||||
|     "Tuple", | ||||
|     "Union", | ||||
|     "compatibility", | ||||
|     "immutable_dict", | ||||
|     "immutable_list", | ||||
|     "normalize_function", | ||||
|     "normalize_module" | ||||
|   ], | ||||
|   "torch.fx.operator_schemas": [ | ||||
|     "Any", | ||||
|     "Callable", | ||||
| @ -1072,53 +1033,6 @@ | ||||
|     "chain", | ||||
|     "compatibility" | ||||
|   ], | ||||
|   "torch.fx.passes.graph_manipulation": [ | ||||
|     "Any", | ||||
|     "Argument", | ||||
|     "Dict", | ||||
|     "Graph", | ||||
|     "GraphModule", | ||||
|     "List", | ||||
|     "NamedTuple", | ||||
|     "Node", | ||||
|     "Optional", | ||||
|     "ShapeProp", | ||||
|     "Target", | ||||
|     "Tuple", | ||||
|     "compatibility", | ||||
|     "lift_lowering_attrs_to_nodes", | ||||
|     "map_aggregate", | ||||
|     "map_arg" | ||||
|   ], | ||||
|   "torch.fx.passes.param_fetch": [ | ||||
|     "Any", | ||||
|     "Callable", | ||||
|     "Dict", | ||||
|     "GraphModule", | ||||
|     "List", | ||||
|     "Tuple", | ||||
|     "Type", | ||||
|     "compatibility" | ||||
|   ], | ||||
|   "torch.fx.passes.shape_prop": [ | ||||
|     "Any", | ||||
|     "Dict", | ||||
|     "NamedTuple", | ||||
|     "Node", | ||||
|     "Optional", | ||||
|     "Tuple", | ||||
|     "compatibility", | ||||
|     "map_aggregate" | ||||
|   ], | ||||
|   "torch.fx.passes.split_module": [ | ||||
|     "Any", | ||||
|     "Callable", | ||||
|     "Dict", | ||||
|     "GraphModule", | ||||
|     "List", | ||||
|     "Optional", | ||||
|     "compatibility" | ||||
|   ], | ||||
|   "torch.fx.proxy": [ | ||||
|     "assert_fn" | ||||
|   ], | ||||
| @ -1412,11 +1326,6 @@ | ||||
|   "torch.nn.intrinsic.modules": [ | ||||
|     "_FusedModule" | ||||
|   ], | ||||
|   "torch.nn.intrinsic.qat.modules.conv_fused": [ | ||||
|     "Parameter", | ||||
|     "TypeVar", | ||||
|     "fuse_conv_bn_weights" | ||||
|   ], | ||||
|   "torch.nn.intrinsic.qat.modules.linear_fused": [ | ||||
|     "Parameter", | ||||
|     "fuse_linear_bn_weights" | ||||
| @ -1436,14 +1345,6 @@ | ||||
|   "torch.nn.parallel.comm": [ | ||||
|     "List" | ||||
|   ], | ||||
|   "torch.nn.parallel.data_parallel": [ | ||||
|     "Module", | ||||
|     "chain", | ||||
|     "gather", | ||||
|     "parallel_apply", | ||||
|     "replicate", | ||||
|     "scatter_kwargs" | ||||
|   ], | ||||
|   "torch.nn.parallel.parallel_apply": [ | ||||
|     "ExceptionWrapper", | ||||
|     "autocast" | ||||
| @ -1452,8 +1353,7 @@ | ||||
|     "OrderedDict" | ||||
|   ], | ||||
|   "torch.nn.parallel.scatter_gather": [ | ||||
|     "Gather", | ||||
|     "Scatter" | ||||
|     "is_namedtuple" | ||||
|   ], | ||||
|   "torch.nn.parameter": [ | ||||
|     "OrderedDict" | ||||
| @ -1514,26 +1414,6 @@ | ||||
|     "Iterable", | ||||
|     "Optional" | ||||
|   ], | ||||
|   "torch.nn.utils.parametrizations": [ | ||||
|     "Enum", | ||||
|     "Module", | ||||
|     "Optional", | ||||
|     "Tensor", | ||||
|     "auto" | ||||
|   ], | ||||
|   "torch.nn.utils.parametrize": [ | ||||
|     "Dict", | ||||
|     "Module", | ||||
|     "ModuleDict", | ||||
|     "ModuleList", | ||||
|     "Optional", | ||||
|     "Parameter", | ||||
|     "Sequence", | ||||
|     "Tensor", | ||||
|     "Tuple", | ||||
|     "Union", | ||||
|     "contextmanager" | ||||
|   ], | ||||
|   "torch.onnx": [ | ||||
|     "Dict", | ||||
|     "OperatorExportTypes", | ||||
| @ -1550,11 +1430,6 @@ | ||||
|     "has_torch_function", | ||||
|     "push_torch_function_mode" | ||||
|   ], | ||||
|   "torch.package.analyze.find_first_use_of_broken_modules": [ | ||||
|     "Dict", | ||||
|     "List", | ||||
|     "PackagingError" | ||||
|   ], | ||||
|   "torch.package.analyze.is_from_package": [ | ||||
|     "Any", | ||||
|     "ModuleType", | ||||
| @ -1919,10 +1794,6 @@ | ||||
|     "rand", | ||||
|     "randn" | ||||
|   ], | ||||
|   "torch.torch_version": [ | ||||
|     "Any", | ||||
|     "Iterable" | ||||
|   ], | ||||
|   "torch.types": [ | ||||
|     "Any", | ||||
|     "Device", | ||||
| @ -1937,14 +1808,6 @@ | ||||
|     "enable_minidumps", | ||||
|     "enable_minidumps_on_exceptions" | ||||
|   ], | ||||
|   "torch.utils.benchmark.utils.common": [ | ||||
|     "_make_temp_dir", | ||||
|     "ordered_unique", | ||||
|     "select_unit", | ||||
|     "set_torch_threads", | ||||
|     "trim_sigfig", | ||||
|     "unit_to_english" | ||||
|   ], | ||||
|   "torch.utils.benchmark.utils.compare": [ | ||||
|     "Colorize", | ||||
|     "Table", | ||||
|  | ||||
| @ -13,6 +13,10 @@ from collections import namedtuple | ||||
| from functools import wraps | ||||
| from typing import Dict, Optional | ||||
|  | ||||
| __all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream', | ||||
|            'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms', | ||||
|            'MetricData'] | ||||
|  | ||||
| MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -14,6 +14,7 @@ import traceback | ||||
| import warnings | ||||
| from typing import Optional | ||||
|  | ||||
| __all__ = ['ErrorHandler'] | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @ -10,6 +10,7 @@ | ||||
|  | ||||
| from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler | ||||
|  | ||||
| __all__ = ['get_error_handler'] | ||||
|  | ||||
| def get_error_handler(): | ||||
|     return ErrorHandler() | ||||
|  | ||||
| @ -4,6 +4,8 @@ from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union | ||||
|  | ||||
| import torch | ||||
|  | ||||
| __all__ = ['Future', 'collect_all', 'wait_all'] | ||||
|  | ||||
| T = TypeVar("T") | ||||
| S = TypeVar("S") | ||||
|  | ||||
|  | ||||
| @ -11,6 +11,8 @@ from torch.fx.operator_schemas import normalize_function, normalize_module, Args | ||||
| if TYPE_CHECKING: | ||||
|     from .graph import Graph | ||||
|  | ||||
| __all__ = ['Node', 'map_arg', 'map_aggregate'] | ||||
|  | ||||
| BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, | ||||
|                           torch.Tensor, torch.device, torch.memory_format, torch.layout] | ||||
| base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined] | ||||
|  | ||||
| @ -9,6 +9,7 @@ from torch.fx.passes.shape_prop import TensorMetadata | ||||
| from torch.fx._compatibility import compatibility | ||||
| from itertools import chain | ||||
|  | ||||
| __all__ = ['FxGraphDrawer'] | ||||
| try: | ||||
|     import pydot | ||||
|     HAS_PYDOT = True | ||||
|  | ||||
| @ -11,6 +11,9 @@ from torch.fx.node import ( | ||||
| ) | ||||
| from torch.fx.passes.shape_prop import ShapeProp | ||||
|  | ||||
| __all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta', | ||||
|            'get_size_of_node'] | ||||
|  | ||||
| @compatibility(is_backward_compatible=False) | ||||
| def replace_target_nodes_with( | ||||
|     fx_module: GraphModule, | ||||
|  | ||||
| @ -5,6 +5,7 @@ import torch.nn as nn | ||||
|  | ||||
| from torch.fx._compatibility import compatibility | ||||
|  | ||||
| __all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes'] | ||||
|  | ||||
| # Matching method matches the attribute name of current version to the attribute name of `target_version` | ||||
| @compatibility(is_backward_compatible=False) | ||||
|  | ||||
| @ -6,6 +6,7 @@ from torch.fx.node import Node, map_aggregate | ||||
| from typing import Any, Tuple, NamedTuple, Optional, Dict | ||||
| from torch.fx._compatibility import compatibility | ||||
|  | ||||
| __all__ = ['TensorMetadata', 'ShapeProp'] | ||||
|  | ||||
| @compatibility(is_backward_compatible=True) | ||||
| class TensorMetadata(NamedTuple): | ||||
|  | ||||
| @ -4,6 +4,8 @@ from typing import Callable, List, Dict, Any, Optional | ||||
| from torch.fx._compatibility import compatibility | ||||
| import inspect | ||||
|  | ||||
| __all__ = ['Partition', 'split_module'] | ||||
|  | ||||
| @compatibility(is_backward_compatible=True) | ||||
| class Partition: | ||||
|     def __init__(self, name: str): | ||||
|  | ||||
| @ -10,6 +10,8 @@ from torch.nn.modules.utils import _single, _pair, _triple | ||||
| from torch.nn.parameter import Parameter | ||||
| from typing import TypeVar | ||||
|  | ||||
| __all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d', | ||||
|            'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats'] | ||||
| _BN_CLASS_MAP = { | ||||
|     1: nn.BatchNorm1d, | ||||
|     2: nn.BatchNorm2d, | ||||
|  | ||||
| @ -13,6 +13,8 @@ from torch._utils import ( | ||||
|     _get_devices_properties | ||||
| ) | ||||
|  | ||||
| __all__ = ['DataParallel', 'data_parallel'] | ||||
|  | ||||
| def _check_balance(device_ids): | ||||
|     imbalance_warn = """ | ||||
|     There is an imbalance between your GPUs. You may want to exclude GPU {} which | ||||
|  | ||||
| @ -1,7 +1,15 @@ | ||||
| import torch | ||||
| from ._functions import Scatter, Gather | ||||
| import warnings | ||||
|  | ||||
| __all__ = ['scatter', 'scatter_kwargs', 'gather'] | ||||
|  | ||||
| def is_namedtuple(obj): | ||||
|     # Check if type was created from collections.namedtuple or a typing.NamedTuple. | ||||
|     warnings.warn("is_namedtuple is deprecated, please use the python checks instead") | ||||
|     return _is_namedtuple(obj) | ||||
|  | ||||
| def _is_namedtuple(obj): | ||||
|     # Check if type was created from collections.namedtuple or a typing.NamedTuple. | ||||
|     return ( | ||||
|         isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") | ||||
|  | ||||
| @ -8,6 +8,7 @@ from .. import functional as F | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
| __all__ = ['orthogonal', 'spectral_norm'] | ||||
|  | ||||
| def _is_orthogonal(Q, eps=None): | ||||
|     n, k = Q.size(-2), Q.size(-1) | ||||
|  | ||||
| @ -7,6 +7,9 @@ import collections | ||||
| from contextlib import contextmanager | ||||
| from typing import Union, Optional, Dict, Tuple, Sequence | ||||
|  | ||||
| __all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations', | ||||
|            'type_before_parametrizations', 'transfer_parametrizations_and_params'] | ||||
|  | ||||
| _cache_enabled = 0 | ||||
| _cache: Dict[Tuple[int, str], Optional[Tensor]] = {} | ||||
|  | ||||
|  | ||||
| @ -2,6 +2,8 @@ from typing import Dict, List | ||||
|  | ||||
| from ..package_exporter import PackagingError | ||||
|  | ||||
| __all__ = ["find_first_use_of_broken_modules"] | ||||
|  | ||||
|  | ||||
| def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]: | ||||
|     """ | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| from typing import Any, Iterable | ||||
| from .version import __version__ as internal_version | ||||
|  | ||||
| __all__ = ['TorchVersion', 'Version', 'InvalidVersion'] | ||||
|  | ||||
| class _LazyImport: | ||||
|     """Wraps around classes lazy imported from packaging.version | ||||
|     Output of the function v in following snippets are identical: | ||||
|  | ||||
| @ -14,7 +14,7 @@ import uuid | ||||
| import torch | ||||
|  | ||||
|  | ||||
| __all__ = ["TaskSpec", "Measurement", "_make_temp_dir"] | ||||
| __all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"] | ||||
|  | ||||
|  | ||||
| _MAX_SIGNIFICANT_FIGURES = 4 | ||||
|  | ||||
| @ -7,7 +7,7 @@ from typing import DefaultDict, List, Optional, Tuple | ||||
| from torch.utils.benchmark.utils import common | ||||
| from torch import tensor as _tensor | ||||
|  | ||||
| __all__ = ["Compare"] | ||||
| __all__ = ["Colorize", "Compare"] | ||||
|  | ||||
| BEST = "\033[92m" | ||||
| GOOD = "\033[34m" | ||||
|  | ||||
| @ -46,6 +46,12 @@ HIPIFY_FINAL_RESULT: HipifyFinalResult = {} | ||||
| to their actual types.""" | ||||
| PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"} | ||||
|  | ||||
| __all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter', | ||||
|            'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group', | ||||
|            'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared', | ||||
|            'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_caffe2_gpu_file', | ||||
|            'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header', | ||||
|            'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify'] | ||||
|  | ||||
| class InputError(Exception): | ||||
|     # Exception raised for errors in the input. | ||||
|  | ||||
		Reference in New Issue
	
	Block a user