mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add __all__ to torch.distributed, futures, fx, nn, package, benchmark submodules (#80520)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80520
Approved by: https://github.com/rohan-varma

committed by: PyTorch MergeBot
parent: 81ca2ff353
commit: 4bf076e964
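For context on the change itself: defining `__all__` makes a module's public surface explicit. `from module import *` then binds only the listed names, and imported helpers (`Optional`, `wraps`, ...) stop leaking as pseudo-public attributes, which is why the corresponding entries come out of the public-API allowlist in the first hunks below. A minimal sketch; the module and names here are hypothetical:

```python
# mymodule.py -- hypothetical module illustrating __all__
from functools import wraps  # without __all__, 'wraps' would look public

__all__ = ['greet']  # the one intentionally public name

def greet(name: str) -> str:
    return f"Hello, {name}!"

def _helper():  # underscore-private by convention
    pass

# elsewhere:
#   from mymodule import *   binds only greet();
#   'wraps' is no longer re-exported once __all__ is defined.
```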
@@ -459,12 +459,6 @@
   "torch.distributed.elastic.metrics": [
     "Optional"
   ],
-  "torch.distributed.elastic.metrics.api": [
-    "Dict",
-    "Optional",
-    "namedtuple",
-    "wraps"
-  ],
   "torch.distributed.elastic.multiprocessing": [
     "Callable",
     "Dict",
@@ -508,12 +502,6 @@
     "get_logger",
     "wraps"
   ],
-  "torch.distributed.elastic.multiprocessing.errors.error_handler": [
-    "Optional"
-  ],
-  "torch.distributed.elastic.multiprocessing.errors.handlers": [
-    "ErrorHandler"
-  ],
   "torch.distributed.elastic.multiprocessing.redirects": [
     "contextmanager",
     "partial",
@@ -920,15 +908,7 @@
     "svd_lowrank"
   ],
   "torch.futures": [
-    "Callable",
-    "Future",
-    "Generic",
-    "List",
-    "Optional",
-    "Type",
-    "TypeVar",
-    "Union",
-    "cast"
+    "Future"
   ],
   "torch.fx": [
     "ProxyableClassMeta",
@@ -1033,25 +1013,6 @@
     "Tuple",
     "compatibility"
   ],
-  "torch.fx.node": [
-    "Any",
-    "ArgsKwargsPair",
-    "Argument",
-    "BaseArgumentTypes",
-    "Callable",
-    "Dict",
-    "List",
-    "Optional",
-    "Set",
-    "Target",
-    "Tuple",
-    "Union",
-    "compatibility",
-    "immutable_dict",
-    "immutable_list",
-    "normalize_function",
-    "normalize_module"
-  ],
   "torch.fx.operator_schemas": [
     "Any",
     "Callable",
@@ -1072,53 +1033,6 @@
     "chain",
     "compatibility"
   ],
-  "torch.fx.passes.graph_manipulation": [
-    "Any",
-    "Argument",
-    "Dict",
-    "Graph",
-    "GraphModule",
-    "List",
-    "NamedTuple",
-    "Node",
-    "Optional",
-    "ShapeProp",
-    "Target",
-    "Tuple",
-    "compatibility",
-    "lift_lowering_attrs_to_nodes",
-    "map_aggregate",
-    "map_arg"
-  ],
-  "torch.fx.passes.param_fetch": [
-    "Any",
-    "Callable",
-    "Dict",
-    "GraphModule",
-    "List",
-    "Tuple",
-    "Type",
-    "compatibility"
-  ],
-  "torch.fx.passes.shape_prop": [
-    "Any",
-    "Dict",
-    "NamedTuple",
-    "Node",
-    "Optional",
-    "Tuple",
-    "compatibility",
-    "map_aggregate"
-  ],
-  "torch.fx.passes.split_module": [
-    "Any",
-    "Callable",
-    "Dict",
-    "GraphModule",
-    "List",
-    "Optional",
-    "compatibility"
-  ],
   "torch.fx.proxy": [
     "assert_fn"
   ],
@@ -1412,11 +1326,6 @@
   "torch.nn.intrinsic.modules": [
     "_FusedModule"
   ],
-  "torch.nn.intrinsic.qat.modules.conv_fused": [
-    "Parameter",
-    "TypeVar",
-    "fuse_conv_bn_weights"
-  ],
   "torch.nn.intrinsic.qat.modules.linear_fused": [
     "Parameter",
     "fuse_linear_bn_weights"
@@ -1436,14 +1345,6 @@
   "torch.nn.parallel.comm": [
     "List"
   ],
-  "torch.nn.parallel.data_parallel": [
-    "Module",
-    "chain",
-    "gather",
-    "parallel_apply",
-    "replicate",
-    "scatter_kwargs"
-  ],
   "torch.nn.parallel.parallel_apply": [
     "ExceptionWrapper",
     "autocast"
@@ -1452,8 +1353,7 @@
     "OrderedDict"
   ],
   "torch.nn.parallel.scatter_gather": [
-    "Gather",
-    "Scatter"
+    "is_namedtuple"
   ],
   "torch.nn.parameter": [
     "OrderedDict"
@@ -1514,26 +1414,6 @@
     "Iterable",
     "Optional"
   ],
-  "torch.nn.utils.parametrizations": [
-    "Enum",
-    "Module",
-    "Optional",
-    "Tensor",
-    "auto"
-  ],
-  "torch.nn.utils.parametrize": [
-    "Dict",
-    "Module",
-    "ModuleDict",
-    "ModuleList",
-    "Optional",
-    "Parameter",
-    "Sequence",
-    "Tensor",
-    "Tuple",
-    "Union",
-    "contextmanager"
-  ],
   "torch.onnx": [
     "Dict",
     "OperatorExportTypes",
@@ -1550,11 +1430,6 @@
     "has_torch_function",
     "push_torch_function_mode"
   ],
-  "torch.package.analyze.find_first_use_of_broken_modules": [
-    "Dict",
-    "List",
-    "PackagingError"
-  ],
   "torch.package.analyze.is_from_package": [
     "Any",
     "ModuleType",
@@ -1919,10 +1794,6 @@
     "rand",
     "randn"
   ],
-  "torch.torch_version": [
-    "Any",
-    "Iterable"
-  ],
   "torch.types": [
     "Any",
     "Device",
@@ -1937,14 +1808,6 @@
     "enable_minidumps",
     "enable_minidumps_on_exceptions"
   ],
-  "torch.utils.benchmark.utils.common": [
-    "_make_temp_dir",
-    "ordered_unique",
-    "select_unit",
-    "set_torch_threads",
-    "trim_sigfig",
-    "unit_to_english"
-  ],
   "torch.utils.benchmark.utils.compare": [
     "Colorize",
     "Table",
@@ -13,6 +13,10 @@ from collections import namedtuple
 from functools import wraps
 from typing import Dict, Optional
 
+__all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream',
+           'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms',
+           'MetricData']
+
 MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
 
 
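The names exported here are the torchelastic metrics entry points. A hedged usage sketch; the handler choice and metric name are illustrative, and the signatures are assumed from the exported names:

```python
import torch.distributed.elastic.metrics as metrics

# Register a handler for the default metric group, then emit a counter.
metrics.configure(metrics.ConsoleMetricHandler())
metrics.put_metric("queue_depth", 3)  # group defaults to "torchelastic"
```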
@@ -14,6 +14,7 @@ import traceback
 import warnings
 from typing import Optional
 
+__all__ = ['ErrorHandler']
 
 log = logging.getLogger(__name__)
 
@@ -10,6 +10,7 @@
 
 from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
 
+__all__ = ['get_error_handler']
 
 def get_error_handler():
     return ErrorHandler()
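`get_error_handler()` is what the `record` decorator in torch.distributed.elastic.multiprocessing.errors uses under the hood; a hedged sketch of the typical entry-point pattern:

```python
from torch.distributed.elastic.multiprocessing.errors import record

@record  # uncaught exceptions are written to the error file via ErrorHandler
def main():
    raise RuntimeError("trainer failed")

if __name__ == "__main__":
    main()
```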
@@ -4,6 +4,8 @@ from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union
 
 import torch
 
+__all__ = ['Future', 'collect_all', 'wait_all']
+
 T = TypeVar("T")
 S = TypeVar("S")
 
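A quick sketch of the three exported names in use:

```python
import torch

fut = torch.futures.Future()   # create an unset future
fut.set_result(torch.ones(2))  # complete it

combined = torch.futures.collect_all([fut])  # a Future over a list of Futures
print(torch.futures.wait_all([fut]))         # [tensor([1., 1.])]
```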
@@ -11,6 +11,8 @@ from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
 if TYPE_CHECKING:
     from .graph import Graph
 
+__all__ = ['Node', 'map_arg', 'map_aggregate']
+
 BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
                           torch.Tensor, torch.device, torch.memory_format, torch.layout]
 base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]
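`map_arg`, one of the exported helpers, applies a function to every Node nested inside an argument structure; a small sketch:

```python
import torch
import torch.fx
from torch.fx.node import map_arg

gm = torch.fx.symbolic_trace(torch.nn.Linear(3, 3))
for node in gm.graph.nodes:
    # replace each Node appearing in node.args with its string name
    print(node.name, map_arg(node.args, lambda n: n.name))
```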
@@ -9,6 +9,7 @@ from torch.fx.passes.shape_prop import TensorMetadata
 from torch.fx._compatibility import compatibility
 from itertools import chain
 
+__all__ = ['FxGraphDrawer']
 try:
     import pydot
     HAS_PYDOT = True
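FxGraphDrawer renders a traced GraphModule through pydot (hence the guarded import above); a hedged sketch, with a hypothetical output path:

```python
import torch
import torch.fx
from torch.fx.passes.graph_drawer import FxGraphDrawer

gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU()))
drawer = FxGraphDrawer(gm, "my_model")            # requires pydot to be installed
drawer.get_dot_graph().write_svg("my_model.svg")  # hypothetical path
```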
@@ -11,6 +11,9 @@ from torch.fx.node import (
 )
 from torch.fx.passes.shape_prop import ShapeProp
 
+__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
+           'get_size_of_node']
+
 @compatibility(is_backward_compatible=False)
 def replace_target_nodes_with(
     fx_module: GraphModule,
@@ -5,6 +5,7 @@ import torch.nn as nn
 
 from torch.fx._compatibility import compatibility
 
+__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']
 
 # Matching method matches the attribute name of current version to the attribute name of `target_version`
 @compatibility(is_backward_compatible=False)
@@ -6,6 +6,7 @@ from torch.fx.node import Node, map_aggregate
 from typing import Any, Tuple, NamedTuple, Optional, Dict
 from torch.fx._compatibility import compatibility
 
+__all__ = ['TensorMetadata', 'ShapeProp']
 
 @compatibility(is_backward_compatible=True)
 class TensorMetadata(NamedTuple):
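ShapeProp interprets a traced graph on example inputs and records a TensorMetadata entry per node; a short sketch:

```python
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp

class M(torch.nn.Module):
    def forward(self, x):
        return (x + 1).relu()

gm = torch.fx.symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(4, 3))  # run once, recording metadata

for node in gm.graph.nodes:
    tm = node.meta.get("tensor_meta")  # TensorMetadata for tensor-valued nodes
    if tm is not None:
        print(node.name, tm.shape, tm.dtype)
```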
@@ -4,6 +4,8 @@ from typing import Callable, List, Dict, Any, Optional
 from torch.fx._compatibility import compatibility
 import inspect
 
+__all__ = ['Partition', 'split_module']
+
 @compatibility(is_backward_compatible=True)
 class Partition:
     def __init__(self, name: str):
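split_module carves a GraphModule into submodules according to a partition id returned per node; a hedged sketch that puts every op in its own partition:

```python
import itertools
import torch
import torch.fx
from torch.fx.passes.split_module import split_module

class M(torch.nn.Module):
    def forward(self, x):
        return (x + 1).relu() * 2

m = M()
gm = torch.fx.symbolic_trace(m)

counter = itertools.count()
split = split_module(gm, m, lambda node: next(counter))  # one partition per op
print(split.graph)  # calls submod_0, submod_1, ... in sequence
```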
@@ -10,6 +10,8 @@ from torch.nn.modules.utils import _single, _pair, _triple
 from torch.nn.parameter import Parameter
 from typing import TypeVar
 
+__all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d',
+           'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats']
 _BN_CLASS_MAP = {
     1: nn.BatchNorm1d,
     2: nn.BatchNorm2d,
@@ -13,6 +13,8 @@ from torch._utils import (
     _get_devices_properties
 )
 
+__all__ = ['DataParallel', 'data_parallel']
+
 def _check_balance(device_ids):
     imbalance_warn = """
 There is an imbalance between your GPUs. You may want to exclude GPU {} which
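The two exported names are the module wrapper and its functional form; a hedged sketch of the wrapper:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 5)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()  # replicate across visible GPUs

x = torch.randn(8, 10, device=next(model.parameters()).device)
y = model(x)  # inputs scattered across replicas, outputs gathered on device 0
```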
@@ -1,7 +1,15 @@
 import torch
 from ._functions import Scatter, Gather
+import warnings
+
+__all__ = ['scatter', 'scatter_kwargs', 'gather']
+
+def is_namedtuple(obj):
+    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
+    warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
+    return _is_namedtuple(obj)
 
 def _is_namedtuple(obj):
     # Check if type was created from collections.namedtuple or a typing.NamedTuple.
     return (
         isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
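The structural check kept private here is easy to see in isolation:

```python
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])

def _is_namedtuple(obj):
    # same check as scatter_gather: a tuple carrying namedtuple's extra attributes
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")

print(_is_namedtuple(Point(1, 2)))  # True
print(_is_namedtuple((1, 2)))       # False
```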
@@ -8,6 +8,7 @@ from .. import functional as F
 
 from typing import Optional
 
+__all__ = ['orthogonal', 'spectral_norm']
 
 def _is_orthogonal(Q, eps=None):
     n, k = Q.size(-2), Q.size(-1)
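Both exported functions wrap a module's weight in a reparametrization; a short sketch:

```python
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal, spectral_norm

ortho_layer = orthogonal(nn.Linear(5, 5))  # weight constrained to be orthogonal
sn_layer = spectral_norm(nn.Linear(5, 5))  # weight rescaled by its spectral norm
```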
@@ -7,6 +7,9 @@ import collections
 from contextlib import contextmanager
 from typing import Union, Optional, Dict, Tuple, Sequence
 
+__all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations',
+           'type_before_parametrizations', 'transfer_parametrizations_and_params']
+
 _cache_enabled = 0
 _cache: Dict[Tuple[int, str], Optional[Tensor]] = {}
 
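register_parametrization, the central export here, attaches a transform that is applied every time the parameter is read; a short sketch:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        # upper triangle mirrored onto the lower triangle
        return X.triu() + X.triu(1).transpose(-1, -2)

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

print(parametrize.is_parametrized(layer, "weight"))  # True
print(torch.allclose(layer.weight, layer.weight.T))  # True: weight is symmetric
```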
@@ -2,6 +2,8 @@ from typing import Dict, List
 
 from ..package_exporter import PackagingError
 
+__all__ = ["find_first_use_of_broken_modules"]
+
 
 def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]:
     """
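The exported function explains a PackagingError by mapping each broken module to the dependency chain that first pulled it in; a hedged sketch, where the package path and `my_model` are hypothetical:

```python
from torch.package import PackageExporter, PackagingError
from torch.package.analyze import find_first_use_of_broken_modules

try:
    with PackageExporter("out.pt") as exporter:               # hypothetical path
        exporter.save_pickle("model", "model.pkl", my_model)  # my_model: hypothetical
except PackagingError as e:
    print(find_first_use_of_broken_modules(e))  # {broken module: [dependency chain]}
```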
@@ -1,6 +1,8 @@
 from typing import Any, Iterable
 from .version import __version__ as internal_version
 
+__all__ = ['TorchVersion', 'Version', 'InvalidVersion']
+
 class _LazyImport:
     """Wraps around classes lazy imported from packaging.version
     Output of the function v in following snippets are identical:
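TorchVersion subclasses str but compares like a packaging Version, including against plain strings; a short sketch:

```python
import torch
from torch.torch_version import TorchVersion

v = TorchVersion("1.12.0")
print(v > "1.11.0")  # True: compared as versions, not as strings
print(isinstance(torch.__version__, TorchVersion))  # True in recent releases
```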
@@ -14,7 +14,7 @@ import uuid
 import torch
 
 
-__all__ = ["TaskSpec", "Measurement", "_make_temp_dir"]
+__all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"]
 
 
 _MAX_SIGNIFICANT_FIGURES = 4
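The newly exported helpers are small formatting utilities; a hedged sketch, with the exact return conventions assumed from the names:

```python
from torch.utils.benchmark.utils import common

unit, scale = common.select_unit(1.5e-4)   # pick a display unit for ~150 us
print(unit, common.unit_to_english(unit))  # e.g. "us", "microsecond"
print(common.trim_sigfig(3.14159, 3))      # rounded to 3 significant figures
```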
@@ -7,7 +7,7 @@ from typing import DefaultDict, List, Optional, Tuple
 from torch.utils.benchmark.utils import common
 from torch import tensor as _tensor
 
-__all__ = ["Compare"]
+__all__ = ["Colorize", "Compare"]
 
 BEST = "\033[92m"
 GOOD = "\033[34m"
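Compare (with the Colorize support now re-exported) collates Measurements into a formatted table; a short sketch:

```python
import torch.utils.benchmark as benchmark

t0 = benchmark.Timer(stmt="x * 2", setup="import torch; x = torch.ones(1024)")
t1 = benchmark.Timer(stmt="x + x", setup="import torch; x = torch.ones(1024)")

compare = benchmark.Compare([t0.timeit(100), t1.timeit(100)])
compare.colorize()  # uses ANSI colors like the constants above
compare.print()
```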
@@ -46,6 +46,12 @@ HIPIFY_FINAL_RESULT: HipifyFinalResult = {}
 to their actual types."""
 PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}
 
+__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
+           'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
+           'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
+           'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_caffe2_gpu_file',
+           'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
+           'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify']
 
 class InputError(Exception):
     # Exception raised for errors in the input.