[BE]: Try TCH autofixes on torch/ (#125536)

Tries TCH autofixes and see what breaks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125536
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan
2024-05-05 23:13:54 +00:00
committed by PyTorch MergeBot
parent ccbac091d2
commit 1dd42e42c4
49 changed files with 255 additions and 110 deletions

View File

@ -1,6 +1,6 @@
import contextlib
import functools
from typing import List, Optional
from typing import List, Optional, TYPE_CHECKING
import torch
from torch._dynamo.external_utils import call_backward, call_hook
@ -21,10 +21,12 @@ from torch.fx.experimental.proxy_tensor import (
track_tensor_tree,
)
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.fx.proxy import Proxy
from torch.fx.traceback import preserve_node_meta, set_stack_trace
from torch.utils._traceback import CapturedTraceback
if TYPE_CHECKING:
from torch.fx.proxy import Proxy
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")

View File

@ -71,7 +71,6 @@ from .guards import (
GuardedCode,
)
from .hooks import Hooks
from .output_graph import OutputGraph
from .replay_record import ExecutionRecord
from .symbolic_convert import InstructionTranslator, SpeculationLog
from .trace_rules import is_numpy
@ -438,6 +437,9 @@ from collections import OrderedDict
from torch.utils.hooks import RemovableHandle
if typing.TYPE_CHECKING:
from .output_graph import OutputGraph
# we have to use `OrderedDict` to make `RemovableHandle` work.
_bytecode_hooks: Dict[int, BytecodeHook] = OrderedDict()

View File

@ -22,7 +22,18 @@ import warnings
import weakref
from enum import Enum
from os.path import dirname, join
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from unittest.mock import patch
import torch
@ -30,7 +41,6 @@ import torch.fx
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch import _guards
from torch._subclasses import fake_tensor
from torch._utils_internal import log_export_usage
from torch.export.dynamic_shapes import _process_dynamic_shapes
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
@ -57,7 +67,6 @@ from . import config, convert_frame, external_utils, trace_rules, utils
from .code_context import code_context
from .exc import CondOpArgsMismatchError, UserError, UserErrorType
from .mutation_guard import install_generation_tagging_init
from .types import CacheEntry, DynamoCallback
from .utils import common_constant_types, compile_times
log = logging.getLogger(__name__)
@ -70,6 +79,10 @@ null_context = contextlib.nullcontext
import sympy
if TYPE_CHECKING:
from torch._subclasses import fake_tensor
from .types import CacheEntry, DynamoCallback
# See https://github.com/python/typing/pull/240
class Unset(Enum):

View File

@ -18,7 +18,18 @@ import textwrap
import types
import weakref
from inspect import currentframe, getframeinfo
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from weakref import ReferenceType
@ -91,6 +102,9 @@ from .utils import (
tuple_iterator_len,
)
if TYPE_CHECKING:
from sympy import Symbol
log = logging.getLogger(__name__)
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
@ -1622,8 +1636,6 @@ class GuardBuilder(GuardBuilderBase):
]
if output_graph.export_constraints:
from sympy import Symbol
source_pairs: List[Tuple[Source, Source]] = []
derived_equalities: List[ # type: ignore[type-arg]
Tuple[Source, Union[Source, Symbol], Callable]

View File

@ -10,7 +10,7 @@ import sys
import traceback
import weakref
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
import sympy
@ -103,6 +103,9 @@ from .variables.tensor import (
from .variables.torch_function import TensorWithTFOverrideVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslatorBase
log = logging.getLogger(__name__)
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
@ -334,7 +337,6 @@ class OutputGraph:
self.global_scope = global_scope
self.local_scope = local_scope
self.root_tx = root_tx
from torch._dynamo.symbolic_convert import InstructionTranslatorBase
# Given a source, what are the user stacks of all locations that
# accessed it?

View File

@ -58,7 +58,9 @@ from .variables import (
UserMethodVariable,
)
from .variables.base import VariableTracker
if typing.TYPE_CHECKING:
from .variables.base import VariableTracker
"""

View File

@ -5,7 +5,7 @@ import functools
import logging
import types
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
import torch._C
import torch.fx
@ -31,6 +31,9 @@ from .lazy import LazyVariableTracker
from .lists import ListVariable, TupleVariable
from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
log = logging.getLogger(__name__)
@ -1474,7 +1477,6 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
def create_wrapped_node(
self, tx, query: "VariableTracker", score_function: "VariableTracker"
):
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
from .builder import SourcelessBuilder

View File

@ -1,7 +1,7 @@
# mypy: ignore-errors
import weakref
from typing import Dict, List
from typing import Dict, List, TYPE_CHECKING
import torch
from torch.utils._pytree import tree_map_only
@ -16,13 +16,15 @@ from ..source import (
)
from ..utils import GLOBAL_KEY_PREFIX
from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from .base import VariableTracker
class ArgMappingException(Exception):
pass

View File

@ -1,7 +1,7 @@
# mypy: ignore-errors
import inspect
from typing import Dict, List
from typing import Dict, List, TYPE_CHECKING
import torch.utils._pytree as pytree
@ -10,12 +10,14 @@ from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource
from ..utils import has_torch_function, is_tensor_base_attr_getter
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from .base import VariableTracker
# [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues):
# At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches

View File

@ -1,7 +1,7 @@
import contextlib
import inspect
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree
@ -37,6 +37,9 @@ from torch.utils._pytree import (
tree_map_with_path,
)
if TYPE_CHECKING:
from sympy import Symbol
def key_path_to_source(kp: KeyPath) -> Source:
"""
@ -159,8 +162,6 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes):
(args, kwargs),
)
from sympy import Symbol
source_pairs: List[Tuple[Source, Source]] = []
derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = []
phantom_symbols: Dict[str, Symbol] = {}

View File

@ -4,12 +4,15 @@ Utils for caching the outputs of AOTAutograd
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import torch
from torch._inductor.codecache import _ident, FxGraphCachePickler
from .schemas import AOTConfig # noqa: F401
if TYPE_CHECKING:
import torch
log = logging.getLogger(__name__)

View File

@ -9,9 +9,7 @@ import math
import operator
import os
from collections import defaultdict
from typing import List, Optional, Set, Tuple, Union
import sympy
from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
import torch
import torch.fx as fx
@ -29,6 +27,9 @@ from torch.fx.passes import graph_drawer
from . import config
from .compile_utils import fx_graph_cse, get_aten_target
if TYPE_CHECKING:
import sympy
AOT_PARTITIONER_DEBUG = config.debug_partitioner
log = logging.getLogger(__name__)

View File

@ -26,7 +26,6 @@ from typing import (
TypeVar,
)
import torch
from torch.utils import _pytree as pytree
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary
@ -35,11 +34,13 @@ log = logging.getLogger(__name__)
if TYPE_CHECKING:
import sympy
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
import sympy
import torch
"""

View File

@ -10,8 +10,6 @@ import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import (
Any,
Callable,
@ -32,6 +30,9 @@ from torch._inductor import ir
from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
if TYPE_CHECKING:
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from torch._inductor.select_algorithm import TritonTemplateCaller
from . import config

View File

@ -23,6 +23,7 @@ from typing import (
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
@ -38,7 +39,6 @@ from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._triton import has_triton_package
from ..._dynamo.utils import counters
@ -86,6 +86,9 @@ from .common import (
from .multi_kernel import MultiKernel
from .triton_utils import config_of, signature_of, signature_to_meta
if TYPE_CHECKING:
from torch.utils._sympy.value_ranges import ValueRanges
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")

View File

@ -60,6 +60,7 @@ from typing import (
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
@ -84,10 +85,12 @@ from torch._inductor.cudagraph_utils import (
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.storage import UntypedStorage
from torch.types import _bool
from torch.utils import _pytree as pytree
from torch.utils.weak import TensorWeakRef
if TYPE_CHECKING:
from torch.types import _bool
StorageWeakRefPointer = int
StorageDataPtr = int
NBytes = int

View File

@ -3,9 +3,7 @@ import itertools
import logging
import operator
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Set, Union
from sympy import Expr
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
import torch
import torch._inductor as inductor
@ -49,6 +47,9 @@ from .pre_grad import is_same_dict, save_inductor_dict
from .reinplace import reinplace_inplaceable_ops
from .split_cat import POST_GRAD_PATTERNS
if TYPE_CHECKING:
from sympy import Expr
log = logging.getLogger(__name__)
aten = torch.ops.aten

View File

@ -7,7 +7,18 @@ import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
DefaultDict,
Dict,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
import sympy
@ -16,7 +27,6 @@ import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._higher_order_ops.effects import _EffectType
from torch._logging import LazyString, trace_structured
from torch._prims_common import make_channels_last_strides_for
from torch._subclasses.fake_tensor import FakeTensor
@ -80,6 +90,9 @@ from .utils import (
)
from .virtualized import V
if TYPE_CHECKING:
from torch._higher_order_ops.effects import _EffectType
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")

View File

@ -2,11 +2,10 @@ from __future__ import annotations
import functools
import logging
from typing import cast, List, Optional, Sequence, Tuple, TypedDict
from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict
import torch
from .. import config, ir
from ..ir import TensorBox
from ..lowering import (
add_layout_constraint,
@ -30,6 +29,9 @@ from ..utils import (
from ..virtualized import V
from .mm_common import filtered_configs
if TYPE_CHECKING:
from ..ir import TensorBox
log = logging.getLogger(__name__)

View File

@ -1,10 +1,12 @@
import logging
from typing import List
from typing import List, TYPE_CHECKING
from ..ir import ChoiceCaller
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
if TYPE_CHECKING:
from ..ir import ChoiceCaller
log = logging.getLogger(__name__)
uint4x2_mixed_mm_template = TritonTemplate(

View File

@ -36,7 +36,6 @@ import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import counters
from torch._prims_common import is_integer_dtype
from torch.fx import Node
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.immutable_collections import immutable_dict, immutable_list
@ -50,6 +49,9 @@ from . import config
from .decomposition import select_decomp_table
from .lowering import fallback_node_due_to_unsupported_type
if typing.TYPE_CHECKING:
from torch.fx import Node
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

View File

@ -12,20 +12,22 @@ from __future__ import annotations
import builtins
import itertools
import operator
from typing import Optional, Sequence
from typing import Optional, Sequence, TYPE_CHECKING
import torch
from . import _dtypes_impl, _util
from ._normalizations import (
ArrayLike,
ArrayLikeOrScalar,
CastingModes,
DTypeLike,
NDArray,
NotImplementedType,
OutArray,
)
if TYPE_CHECKING:
from ._normalizations import (
ArrayLike,
ArrayLikeOrScalar,
CastingModes,
DTypeLike,
NDArray,
NotImplementedType,
OutArray,
)
def copy(

View File

@ -8,19 +8,21 @@ Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype insta
from __future__ import annotations
import functools
from typing import Optional
from typing import Optional, TYPE_CHECKING
import torch
from . import _dtypes_impl, _util
from ._normalizations import (
ArrayLike,
AxisLike,
DTypeLike,
KeepDims,
NotImplementedType,
OutArray,
)
if TYPE_CHECKING:
from ._normalizations import (
ArrayLike,
AxisLike,
DTypeLike,
KeepDims,
NotImplementedType,
OutArray,
)
def _deco_axis_expand(func):

View File

@ -40,7 +40,6 @@ from torch._utils import render_call
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
from torch.types import _bool
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
@ -52,6 +51,7 @@ from torch.utils._traceback import CapturedTraceback
if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.types import _bool
class _Unassigned:

View File

@ -34,7 +34,6 @@ from torch._C._functorch import (
maybe_get_level,
peek_interpreter_stack,
)
from torch._guards import Source
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import WeakIdKeyDictionary
@ -42,6 +41,7 @@ from torch.utils.weak import WeakIdKeyDictionary
if TYPE_CHECKING:
from torch._C._autograd import CreationMeta
from torch._C._functorch import CInterpreter
from torch._guards import Source
# Import here to avoid cycle
from torch._subclasses.fake_tensor import FakeTensorMode

View File

@ -125,7 +125,6 @@ from torch.ao.quantization.fx.match_utils import _find_matches
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization import QConfigMapping
from torch.ao.ns.fx.n_shadows_utils import (
OutputProp,
@ -140,7 +139,10 @@ from torch.ao.ns.fx.n_shadows_utils import (
)
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type
from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type, TYPE_CHECKING
if TYPE_CHECKING:
from torch.ao.quantization.qconfig import QConfigAny
RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

View File

@ -1,12 +1,14 @@
from __future__ import annotations
import copy
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Union, TYPE_CHECKING
import torch
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
from torch.ao.quantization.qconfig import QConfigAny
if TYPE_CHECKING:
from torch.ao.quantization.qconfig import QConfigAny
__all__ = ["QConfigMultiMapping"]

View File

@ -1,7 +1,6 @@
from collections import defaultdict
from copy import deepcopy
import torch
from typing import Any, Optional, Dict
from typing import Any, Optional, Dict, TYPE_CHECKING
import pytorch_lightning as pl # type: ignore[import]
from ._data_sparstity_utils import (
@ -10,6 +9,9 @@ from ._data_sparstity_utils import (
_get_valid_name
)
if TYPE_CHECKING:
import torch
class PostTrainingDataSparsity(pl.callbacks.Callback):
"""Lightning callback that enables post-training sparsity.

View File

@ -1,11 +1,13 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union, TYPE_CHECKING
import torch
from torch.ao.quantization.utils import Pattern
from enum import Enum
if TYPE_CHECKING:
from torch.ao.quantization.utils import Pattern
__all__ = [
"BackendConfig",

View File

@ -1,13 +1,13 @@
from __future__ import annotations
from typing import Dict, List
import torch
from torch.fx import Node
from typing import Dict, List, TYPE_CHECKING
from .quantizer import QuantizationAnnotation, Quantizer
if TYPE_CHECKING:
import torch
from torch.fx import Node
__all__ = [
"ComposableQuantizer",
]

View File

@ -4,7 +4,17 @@ import itertools
import operator
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
)
import torch
import torch.nn.functional as F
@ -20,7 +30,6 @@ from torch.ao.quantization.observer import (
PlaceholderObserver,
)
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
@ -43,6 +52,9 @@ from torch.fx.passes.utils.source_matcher_utils import (
SourcePartition,
)
if TYPE_CHECKING:
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
__all__ = [
"X86InductorQuantizer",
"get_default_x86_inductor_quantization_config",

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
import torch
import torch._dynamo as torchdynamo
@ -21,8 +21,6 @@ from torch.ao.quantization.observer import (
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@ -34,7 +32,10 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
QuantizationConfig,
)
from torch.fx import Node
if TYPE_CHECKING:
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node
__all__ = [

View File

@ -1,6 +1,6 @@
import functools
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import torch
import torch.nn as nn
@ -15,9 +15,11 @@ from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map
from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import _cast_fp_tensor, TrainingState
from ._fsdp_param import FSDPParam
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup
if TYPE_CHECKING:
from ._fsdp_param import FSDPParam
class FSDPStateContext:
"""This has state shared across FSDP states."""

View File

@ -8,6 +8,7 @@ from typing import (
Sequence,
Tuple,
cast,
TYPE_CHECKING,
)
import copy
import warnings
@ -19,7 +20,6 @@ import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.distributed import distributed_c10d
from torch.distributed._shard.metadata import ShardMetadata
import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed._shard.sharding_spec.api import (
_dispatch_custom_op,
@ -47,6 +47,9 @@ from torch.distributed.remote_device import _remote_device
from torch.utils import _pytree as pytree
import operator
if TYPE_CHECKING:
from torch.distributed._shard.metadata import ShardMetadata
# Tracking for sharded tensor objects.
_sharded_tensor_lock = threading.Lock()
_sharded_tensor_current_id = 0

View File

@ -1,6 +1,6 @@
import collections.abc
import copy
from typing import Optional, List, Sequence
from typing import Optional, List, Sequence, TYPE_CHECKING
import torch
from torch.distributed import distributed_c10d as c10d
@ -10,10 +10,12 @@ from torch.distributed._shard.sharding_spec._internals import (
validate_non_overlapping_shards_metadata,
)
from torch.distributed._shard.metadata import ShardMetadata
from .metadata import TensorProperties, ShardedTensorMetadata
from .shard import Shard
if TYPE_CHECKING:
from torch.distributed._shard.metadata import ShardMetadata
def _parse_and_validate_remote_device(pg, remote_device):
if remote_device is None:
raise ValueError("remote device is None")

View File

@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import operator
from typing import cast, Dict, List, Optional, Sequence, Tuple
from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
import torch
@ -24,7 +24,9 @@ from torch.distributed._tensor.tp_conv import (
convolution_backward_handler,
convolution_handler,
)
from torch.distributed.device_mesh import DeviceMesh
if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
try:
from torch.utils import _cxx_pytree as pytree

View File

@ -5,7 +5,7 @@ sharding with the DTensor API.
import argparse
import os
from functools import cached_property
from typing import List
from typing import List, TYPE_CHECKING
import torch
@ -17,13 +17,15 @@ from torch.distributed._tensor import (
Shard,
)
from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.checkpoint.metadata import (
ChunkStorageMetadata,
TensorProperties,
TensorStorageMetadata,
)
if TYPE_CHECKING:
from torch.distributed._tensor.placement_types import Placement
def get_device_type():
return (

View File

@ -11,14 +11,17 @@ from typing import (
List,
no_type_check,
Sequence,
TYPE_CHECKING,
)
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle
from torch.utils._python_dispatch import TorchDispatchMode
import operator
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
BYTES_PER_MB = 1024 * 1024.0

View File

@ -1,11 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from collections import defaultdict
from typing import Dict, List, Set
from typing import Dict, List, Set, TYPE_CHECKING
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import SavePlan, WriteItem
if TYPE_CHECKING:
from torch.distributed.checkpoint.metadata import MetadataIndex
__all__ = ["dedup_save_plans"]

View File

@ -1,11 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
import logging
from typing import Dict, List
from typing import Dict, List, TYPE_CHECKING
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import SavePlan
if TYPE_CHECKING:
from torch.distributed.checkpoint.metadata import MetadataIndex
__all__ = ["dedup_tensors"]

View File

@ -1,16 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
from typing import TYPE_CHECKING
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.remote_device import _remote_device
from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
from .utils import _element_wise_add, _normalize_device_info
if TYPE_CHECKING:
from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
# TODO: We need to refactor this code.
def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:

View File

@ -14,7 +14,7 @@ import socket
from string import Template
import time
import uuid
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
@ -30,12 +30,14 @@ from torch.distributed.elastic.agent.server.health_check_server import (
create_healthcheck_server,
HealthCheckServer,
)
from torch.distributed.elastic.events.api import EventMetadataValue
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import PContext, start_processes, LogsSpecs
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
if TYPE_CHECKING:
from torch.distributed.elastic.events.api import EventMetadataValue
logger = get_logger(__name__)
__all__ = [

View File

@ -9,10 +9,12 @@
import logging
import os
import time
from concurrent.futures._base import Future
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Dict, List, Optional, TextIO
from typing import Dict, List, Optional, TextIO, TYPE_CHECKING
if TYPE_CHECKING:
from concurrent.futures._base import Future
__all__ = ["tail_logfile", "TailLog"]

View File

@ -31,8 +31,6 @@ from torch.distributed._composable_state import _get_module_state, _State
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
from torch.distributed.utils import _apply_to_tensors
from torch.utils._mode_utils import no_dispatch
@ -46,6 +44,8 @@ from .api import (
)
if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
from ._flat_param import FlatParamHandle
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"

View File

@ -15,6 +15,7 @@ from typing import (
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
@ -58,7 +59,9 @@ from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.utils import _sync_params_and_buffers
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.hooks import RemovableHandle
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
_TORCHDISTX_AVAIL = True
try:

View File

@ -17,6 +17,7 @@ from typing import (
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
@ -24,7 +25,6 @@ import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed.distributed_c10d import _get_pg_default_device
@ -54,6 +54,9 @@ from torch.distributed.fsdp.api import (
)
from torch.utils._pytree import tree_map_only
if TYPE_CHECKING:
from torch.distributed._shard.sharded_tensor import ShardedTensor
logger = logging.getLogger(__name__)

View File

@ -11,14 +11,13 @@ from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING
import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle
RPC_AVAILABLE = False
if dist.is_available():
@ -44,6 +43,9 @@ from torch._utils import _get_device_index
from ..modules import Module
from .scatter_gather import gather, scatter_kwargs # noqa: F401
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
__all__ = ["DistributedDataParallel"]
logger = logging.getLogger(__name__)

View File

@ -7,7 +7,18 @@ import linecache
import os
import types
from contextlib import contextmanager
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from typing import (
Any,
BinaryIO,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
TYPE_CHECKING,
Union,
)
from weakref import WeakValueDictionary
import torch
@ -24,9 +35,11 @@ from ._importlib import (
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .glob_group import GlobPattern
from .importer import Importer
if TYPE_CHECKING:
from .glob_group import GlobPattern
__all__ = ["PackageImporter"]

View File

@ -3,13 +3,15 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, TYPE_CHECKING
from torch.autograd import _KinetoEvent
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
if TYPE_CHECKING:
from torch.autograd import _KinetoEvent
def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False):
order = reversed if reverse else lambda x: x