Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[BE]: Try TCH autofixes on torch/ (#125536)

Tries TCH autofixes and sees what breaks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125536
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
Parent: ccbac091d2
Commit: 1dd42e42c4
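The hunks below all apply the same mechanical rewrite: imports that are only needed for type annotations are moved under an `if TYPE_CHECKING:` guard so they are not executed at runtime. A minimal sketch of the pattern, with illustrative names (`expensive_module`, `HeavyClass`, `process` are not from the diff):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers (mypy, pyright); never executed at
    # runtime, which avoids import cycles and trims module import time.
    from expensive_module import HeavyClass  # illustrative placeholder


def process(item: HeavyClass) -> None:
    # With `from __future__ import annotations`, annotations are not evaluated
    # at runtime, so HeavyClass only needs to exist for the type checker.
    ...
```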
@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import List, Optional
+from typing import List, Optional, TYPE_CHECKING

 import torch
 from torch._dynamo.external_utils import call_backward, call_hook
@@ -21,10 +21,12 @@ from torch.fx.experimental.proxy_tensor import (
     track_tensor_tree,
 )
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
-from torch.fx.proxy import Proxy
 from torch.fx.traceback import preserve_node_meta, set_stack_trace
 from torch.utils._traceback import CapturedTraceback

+if TYPE_CHECKING:
+    from torch.fx.proxy import Proxy
+
 compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
 verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")

@@ -71,7 +71,6 @@ from .guards import (
     GuardedCode,
 )
 from .hooks import Hooks
-from .output_graph import OutputGraph
 from .replay_record import ExecutionRecord
 from .symbolic_convert import InstructionTranslator, SpeculationLog
 from .trace_rules import is_numpy
@@ -438,6 +437,9 @@ from collections import OrderedDict

 from torch.utils.hooks import RemovableHandle

+if typing.TYPE_CHECKING:
+    from .output_graph import OutputGraph
+
 # we have to use `OrderedDict` to make `RemovableHandle` work.
 _bytecode_hooks: Dict[int, BytecodeHook] = OrderedDict()

@@ -22,7 +22,18 @@ import warnings
 import weakref
 from enum import Enum
 from os.path import dirname, join
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 from unittest.mock import patch

 import torch
@@ -30,7 +41,6 @@ import torch.fx
 import torch.utils._pytree as pytree
 import torch.utils.checkpoint
 from torch import _guards
-from torch._subclasses import fake_tensor
 from torch._utils_internal import log_export_usage
 from torch.export.dynamic_shapes import _process_dynamic_shapes
 from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
@@ -57,7 +67,6 @@ from . import config, convert_frame, external_utils, trace_rules, utils
 from .code_context import code_context
 from .exc import CondOpArgsMismatchError, UserError, UserErrorType
 from .mutation_guard import install_generation_tagging_init
-from .types import CacheEntry, DynamoCallback
 from .utils import common_constant_types, compile_times

 log = logging.getLogger(__name__)
@@ -70,6 +79,10 @@ null_context = contextlib.nullcontext

 import sympy

+if TYPE_CHECKING:
+    from torch._subclasses import fake_tensor
+    from .types import CacheEntry, DynamoCallback
+

 # See https://github.com/python/typing/pull/240
 class Unset(Enum):
@@ -18,7 +18,18 @@ import textwrap
 import types
 import weakref
 from inspect import currentframe, getframeinfo
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
 from weakref import ReferenceType


@@ -91,6 +102,9 @@ from .utils import (
     tuple_iterator_len,
 )

+if TYPE_CHECKING:
+    from sympy import Symbol
+
 log = logging.getLogger(__name__)
 guards_log = torch._logging.getArtifactLogger(__name__, "guards")
 recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
@@ -1622,8 +1636,6 @@ class GuardBuilder(GuardBuilderBase):
             ]

         if output_graph.export_constraints:
-            from sympy import Symbol
-
             source_pairs: List[Tuple[Source, Source]] = []
             derived_equalities: List[  # type: ignore[type-arg]
                 Tuple[Source, Union[Source, Symbol], Callable]
@@ -10,7 +10,7 @@ import sys
 import traceback
 import weakref
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

 import sympy

@@ -103,6 +103,9 @@ from .variables.tensor import (

 from .variables.torch_function import TensorWithTFOverrideVariable

+if TYPE_CHECKING:
+    from torch._dynamo.symbolic_convert import InstructionTranslatorBase
+
 log = logging.getLogger(__name__)
 graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
 graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
@@ -334,7 +337,6 @@ class OutputGraph:
         self.global_scope = global_scope
         self.local_scope = local_scope
         self.root_tx = root_tx
-        from torch._dynamo.symbolic_convert import InstructionTranslatorBase

         # Given a source, what are the user stacks of all locations that
         # accessed it?
@@ -58,7 +58,9 @@ from .variables import (
     UserMethodVariable,
 )

-from .variables.base import VariableTracker

+if typing.TYPE_CHECKING:
+    from .variables.base import VariableTracker
+

 """
@@ -5,7 +5,7 @@ import functools
 import logging
 import types

-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, TYPE_CHECKING

 import torch._C
 import torch.fx
@@ -31,6 +31,9 @@ from .lazy import LazyVariableTracker
 from .lists import ListVariable, TupleVariable
 from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable

+if TYPE_CHECKING:
+    from torch._dynamo.symbolic_convert import InstructionTranslator
+

 log = logging.getLogger(__name__)

@@ -1474,7 +1477,6 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
     def create_wrapped_node(
         self, tx, query: "VariableTracker", score_function: "VariableTracker"
     ):
-        from torch._dynamo.symbolic_convert import InstructionTranslator
         from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
         from .builder import SourcelessBuilder

@@ -1,7 +1,7 @@
 # mypy: ignore-errors

 import weakref
-from typing import Dict, List
+from typing import Dict, List, TYPE_CHECKING

 import torch
 from torch.utils._pytree import tree_map_only
@@ -16,13 +16,15 @@ from ..source import (
 )
 from ..utils import GLOBAL_KEY_PREFIX

-from .base import VariableTracker
 from .constant import ConstantVariable
 from .dicts import ConstDictVariable
 from .lists import ListVariable
 from .misc import GetAttrVariable
 from .user_defined import UserDefinedObjectVariable

+if TYPE_CHECKING:
+    from .base import VariableTracker
+

 class ArgMappingException(Exception):
     pass
@@ -1,7 +1,7 @@
 # mypy: ignore-errors

 import inspect
-from typing import Dict, List
+from typing import Dict, List, TYPE_CHECKING

 import torch.utils._pytree as pytree

@@ -10,12 +10,14 @@ from ..exc import unimplemented
 from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource, GlobalSource
 from ..utils import has_torch_function, is_tensor_base_attr_getter
-from .base import VariableTracker
 from .constant import ConstantVariable
 from .lists import TupleVariable
 from .tensor import TensorSubclassVariable, TensorVariable
 from .user_defined import UserDefinedObjectVariable

+if TYPE_CHECKING:
+    from .base import VariableTracker
+

 # [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues):
 # At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches
@@ -1,7 +1,7 @@
 import contextlib
 import inspect
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union

 import torch
 import torch.utils._pytree as pytree
@@ -37,6 +37,9 @@ from torch.utils._pytree import (
     tree_map_with_path,
 )

+if TYPE_CHECKING:
+    from sympy import Symbol
+

 def key_path_to_source(kp: KeyPath) -> Source:
     """
@@ -159,8 +162,6 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes):
         (args, kwargs),
     )

-    from sympy import Symbol
-
     source_pairs: List[Tuple[Source, Source]] = []
     derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = []
     phantom_symbols: Dict[str, Symbol] = {}
@@ -4,12 +4,15 @@ Utils for caching the outputs of AOTAutograd
 from __future__ import annotations

 import logging
+from typing import TYPE_CHECKING

-import torch
 from torch._inductor.codecache import _ident, FxGraphCachePickler

 from .schemas import AOTConfig  # noqa: F401

+if TYPE_CHECKING:
+    import torch
+
 log = logging.getLogger(__name__)

@@ -9,9 +9,7 @@ import math
 import operator
 import os
 from collections import defaultdict
-from typing import List, Optional, Set, Tuple, Union
-
-import sympy
+from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union

 import torch
 import torch.fx as fx
@@ -29,6 +27,9 @@ from torch.fx.passes import graph_drawer
 from . import config
 from .compile_utils import fx_graph_cse, get_aten_target

+if TYPE_CHECKING:
+    import sympy
+

 AOT_PARTITIONER_DEBUG = config.debug_partitioner
 log = logging.getLogger(__name__)
@@ -26,7 +26,6 @@ from typing import (
     TypeVar,
 )

-import torch
 from torch.utils import _pytree as pytree
 from torch.utils._traceback import CapturedTraceback
 from torch.utils.weak import WeakTensorKeyDictionary
@@ -35,11 +34,13 @@ log = logging.getLogger(__name__)


 if TYPE_CHECKING:
+    import sympy
+
     # Import the following modules during type checking to enable code intelligence features,
     # such as auto-completion in tools like pylance, even when these modules are not explicitly
     # imported in user code.

-    import sympy
+    import torch


 """
@@ -10,8 +10,6 @@ import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from ctypes import byref, c_size_t, c_void_p
-from multiprocessing.process import BaseProcess
-from multiprocessing.queues import Queue
 from typing import (
     Any,
     Callable,
@@ -32,6 +30,9 @@ from torch._inductor import ir
 from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache

 if TYPE_CHECKING:
+    from multiprocessing.process import BaseProcess
+    from multiprocessing.queues import Queue
+
     from torch._inductor.select_algorithm import TritonTemplateCaller

 from . import config
@@ -23,6 +23,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    TYPE_CHECKING,
     Union,
 )

@@ -38,7 +39,6 @@ from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
 from torch._prims_common import is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
-from torch.utils._sympy.value_ranges import ValueRanges
 from torch.utils._triton import has_triton_package

 from ..._dynamo.utils import counters
@@ -86,6 +86,9 @@ from .common import (
 from .multi_kernel import MultiKernel
 from .triton_utils import config_of, signature_of, signature_to_meta

+if TYPE_CHECKING:
+    from torch.utils._sympy.value_ranges import ValueRanges
+
 log = logging.getLogger(__name__)
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
 schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
@@ -60,6 +60,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    TYPE_CHECKING,
     Union,
 )

@@ -84,10 +85,12 @@ from torch._inductor.cudagraph_utils import (
 )
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.storage import UntypedStorage
-from torch.types import _bool
 from torch.utils import _pytree as pytree
 from torch.utils.weak import TensorWeakRef

+if TYPE_CHECKING:
+    from torch.types import _bool
+
 StorageWeakRefPointer = int
 StorageDataPtr = int
 NBytes = int
@@ -3,9 +3,7 @@ import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Dict, List, Optional, Set, Union
-
-from sympy import Expr
+from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union

 import torch
 import torch._inductor as inductor
@@ -49,6 +47,9 @@ from .pre_grad import is_same_dict, save_inductor_dict
 from .reinplace import reinplace_inplaceable_ops
 from .split_cat import POST_GRAD_PATTERNS

+if TYPE_CHECKING:
+    from sympy import Expr
+

 log = logging.getLogger(__name__)
 aten = torch.ops.aten
@@ -7,7 +7,18 @@ import sys
 import time
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)

 import sympy

@@ -16,7 +27,6 @@ import torch._logging
 import torch.fx
 from torch._decomp import get_decompositions
 from torch._dynamo.utils import defake, dynamo_timed
-from torch._higher_order_ops.effects import _EffectType
 from torch._logging import LazyString, trace_structured
 from torch._prims_common import make_channels_last_strides_for
 from torch._subclasses.fake_tensor import FakeTensor
@@ -80,6 +90,9 @@ from .utils import (
 )
 from .virtualized import V

+if TYPE_CHECKING:
+    from torch._higher_order_ops.effects import _EffectType
+
 log = logging.getLogger(__name__)
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
 output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
@@ -2,11 +2,10 @@ from __future__ import annotations

 import functools
 import logging
-from typing import cast, List, Optional, Sequence, Tuple, TypedDict
+from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict

 import torch
 from .. import config, ir
-from ..ir import TensorBox

 from ..lowering import (
     add_layout_constraint,
@@ -30,6 +29,9 @@ from ..utils import (
 from ..virtualized import V
 from .mm_common import filtered_configs

+if TYPE_CHECKING:
+    from ..ir import TensorBox
+
 log = logging.getLogger(__name__)


@@ -1,10 +1,12 @@
 import logging
-from typing import List
+from typing import List, TYPE_CHECKING

-from ..ir import ChoiceCaller
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate
 from .mm_common import mm_args, mm_configs, mm_grid, mm_options

+if TYPE_CHECKING:
+    from ..ir import ChoiceCaller
+
 log = logging.getLogger(__name__)

 uint4x2_mixed_mm_template = TritonTemplate(
@@ -36,7 +36,6 @@ import torch.utils._pytree as pytree
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import counters
 from torch._prims_common import is_integer_dtype
-from torch.fx import Node
 from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
 from torch.fx.immutable_collections import immutable_dict, immutable_list
@@ -50,6 +49,9 @@ from . import config
 from .decomposition import select_decomp_table
 from .lowering import fallback_node_due_to_unsupported_type

+if typing.TYPE_CHECKING:
+    from torch.fx import Node
+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
 prims = torch.ops.prims
@@ -12,20 +12,22 @@ from __future__ import annotations

 import builtins
 import itertools
 import operator
-from typing import Optional, Sequence
+from typing import Optional, Sequence, TYPE_CHECKING

 import torch

 from . import _dtypes_impl, _util
-from ._normalizations import (
-    ArrayLike,
-    ArrayLikeOrScalar,
-    CastingModes,
-    DTypeLike,
-    NDArray,
-    NotImplementedType,
-    OutArray,
-)
+
+if TYPE_CHECKING:
+    from ._normalizations import (
+        ArrayLike,
+        ArrayLikeOrScalar,
+        CastingModes,
+        DTypeLike,
+        NDArray,
+        NotImplementedType,
+        OutArray,
+    )


 def copy(
@@ -8,19 +8,21 @@ Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
 from __future__ import annotations

 import functools
-from typing import Optional
+from typing import Optional, TYPE_CHECKING

 import torch

 from . import _dtypes_impl, _util
-from ._normalizations import (
-    ArrayLike,
-    AxisLike,
-    DTypeLike,
-    KeepDims,
-    NotImplementedType,
-    OutArray,
-)
+
+if TYPE_CHECKING:
+    from ._normalizations import (
+        ArrayLike,
+        AxisLike,
+        DTypeLike,
+        KeepDims,
+        NotImplementedType,
+        OutArray,
+    )


 def _deco_axis_expand(func):
@@ -40,7 +40,6 @@ from torch._utils import render_call
 from torch.fx.operator_schemas import normalize_function
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.overrides import TorchFunctionMode
-from torch.types import _bool
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
@@ -52,6 +51,7 @@ from torch.utils._traceback import CapturedTraceback

 if TYPE_CHECKING:
     from torch.fx.experimental.symbolic_shapes import ShapeEnv
+    from torch.types import _bool


 class _Unassigned:
@@ -34,7 +34,6 @@ from torch._C._functorch import (
     maybe_get_level,
     peek_interpreter_stack,
 )
-from torch._guards import Source

 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from torch.utils.weak import WeakIdKeyDictionary
@@ -42,6 +41,7 @@ from torch.utils.weak import WeakIdKeyDictionary
 if TYPE_CHECKING:
     from torch._C._autograd import CreationMeta
     from torch._C._functorch import CInterpreter
+    from torch._guards import Source

     # Import here to avoid cycle
     from torch._subclasses.fake_tensor import FakeTensorMode
@@ -125,7 +125,6 @@ from torch.ao.quantization.fx.match_utils import _find_matches
 from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
 from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
 from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
-from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization import QConfigMapping
 from torch.ao.ns.fx.n_shadows_utils import (
     OutputProp,
@@ -140,7 +139,10 @@ from torch.ao.ns.fx.n_shadows_utils import (
 )
 from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping

-from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type
+from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import QConfigAny

 RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

@@ -1,12 +1,14 @@
 from __future__ import annotations

 import copy
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List, Union, TYPE_CHECKING

 import torch
 from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
-from torch.ao.quantization.qconfig import QConfigAny
+
+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import QConfigAny

 __all__ = ["QConfigMultiMapping"]

@@ -1,7 +1,6 @@
 from collections import defaultdict
 from copy import deepcopy
-import torch
-from typing import Any, Optional, Dict
+from typing import Any, Optional, Dict, TYPE_CHECKING
 import pytorch_lightning as pl  # type: ignore[import]

 from ._data_sparstity_utils import (
@@ -10,6 +9,9 @@ from ._data_sparstity_utils import (
     _get_valid_name
 )

+if TYPE_CHECKING:
+    import torch
+

 class PostTrainingDataSparsity(pl.callbacks.Callback):
     """Lightning callback that enables post-training sparsity.
@@ -1,11 +1,13 @@
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type, Union, TYPE_CHECKING

 import torch
-from torch.ao.quantization.utils import Pattern
 from enum import Enum

+if TYPE_CHECKING:
+    from torch.ao.quantization.utils import Pattern
+

 __all__ = [
     "BackendConfig",
@@ -1,13 +1,13 @@
 from __future__ import annotations

-from typing import Dict, List
-
-import torch
-
-from torch.fx import Node
+from typing import Dict, List, TYPE_CHECKING

 from .quantizer import QuantizationAnnotation, Quantizer

+if TYPE_CHECKING:
+    import torch
+    from torch.fx import Node
+
 __all__ = [
     "ComposableQuantizer",
 ]
@@ -4,7 +4,17 @@ import itertools
 import operator
 import warnings
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    TYPE_CHECKING,
+)

 import torch
 import torch.nn.functional as F
@@ -20,7 +30,6 @@ from torch.ao.quantization.observer import (
     PlaceholderObserver,
 )
 from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 from torch.ao.quantization.quantizer.quantizer import (
     QuantizationAnnotation,
     QuantizationSpec,
@@ -43,6 +52,9 @@ from torch.fx.passes.utils.source_matcher_utils import (
     SourcePartition,
 )

+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+
 __all__ = [
     "X86InductorQuantizer",
     "get_default_x86_inductor_quantization_config",
@@ -3,7 +3,7 @@ from __future__ import annotations
 import copy
 import functools

-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING

 import torch
 import torch._dynamo as torchdynamo
@@ -21,8 +21,6 @@ from torch.ao.quantization.observer import (
     PlaceholderObserver,
 )

-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -34,7 +32,10 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
     QuantizationConfig,
 )

-from torch.fx import Node

+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+    from torch.fx import Node
+

 __all__ = [
@@ -1,6 +1,6 @@
 import functools

-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

 import torch
 import torch.nn as nn
@@ -15,9 +15,11 @@ from torch.distributed.utils import _to_kwargs
 from torch.utils._pytree import tree_flatten, tree_map
 from ._fsdp_api import MixedPrecisionPolicy
 from ._fsdp_common import _cast_fp_tensor, TrainingState
-from ._fsdp_param import FSDPParam
 from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup

+if TYPE_CHECKING:
+    from ._fsdp_param import FSDPParam
+

 class FSDPStateContext:
     """This has state shared across FSDP states."""
@@ -8,6 +8,7 @@ from typing import (
     Sequence,
     Tuple,
     cast,
+    TYPE_CHECKING,
 )
 import copy
 import warnings
@@ -19,7 +20,6 @@ import torch
 import torch.distributed as dist
 from torch.distributed import rpc
 from torch.distributed import distributed_c10d
-from torch.distributed._shard.metadata import ShardMetadata
 import torch.distributed._shard.sharding_spec as shard_spec
 from torch.distributed._shard.sharding_spec.api import (
     _dispatch_custom_op,
@@ -47,6 +47,9 @@ from torch.distributed.remote_device import _remote_device
 from torch.utils import _pytree as pytree
 import operator

+if TYPE_CHECKING:
+    from torch.distributed._shard.metadata import ShardMetadata
+
 # Tracking for sharded tensor objects.
 _sharded_tensor_lock = threading.Lock()
 _sharded_tensor_current_id = 0
@@ -1,6 +1,6 @@
 import collections.abc
 import copy
-from typing import Optional, List, Sequence
+from typing import Optional, List, Sequence, TYPE_CHECKING

 import torch
 from torch.distributed import distributed_c10d as c10d
@@ -10,10 +10,12 @@ from torch.distributed._shard.sharding_spec._internals import (
     validate_non_overlapping_shards_metadata,
 )

-from torch.distributed._shard.metadata import ShardMetadata
 from .metadata import TensorProperties, ShardedTensorMetadata
 from .shard import Shard

+if TYPE_CHECKING:
+    from torch.distributed._shard.metadata import ShardMetadata
+
 def _parse_and_validate_remote_device(pg, remote_device):
     if remote_device is None:
         raise ValueError("remote device is None")
@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import functools
 import operator
-from typing import cast, Dict, List, Optional, Sequence, Tuple
+from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING

 import torch

@@ -24,7 +24,9 @@ from torch.distributed._tensor.tp_conv import (
     convolution_backward_handler,
     convolution_handler,
 )
-from torch.distributed.device_mesh import DeviceMesh

+if TYPE_CHECKING:
+    from torch.distributed.device_mesh import DeviceMesh
+
 try:
     from torch.utils import _cxx_pytree as pytree
@@ -5,7 +5,7 @@ sharding with the DTensor API.
 import argparse
 import os
 from functools import cached_property
-from typing import List
+from typing import List, TYPE_CHECKING

 import torch

@@ -17,13 +17,15 @@ from torch.distributed._tensor import (
     Shard,
 )
 from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
-from torch.distributed._tensor.placement_types import Placement
 from torch.distributed.checkpoint.metadata import (
     ChunkStorageMetadata,
     TensorProperties,
     TensorStorageMetadata,
 )

+if TYPE_CHECKING:
+    from torch.distributed._tensor.placement_types import Placement
+

 def get_device_type():
     return (
@@ -11,14 +11,17 @@ from typing import (
     List,
     no_type_check,
     Sequence,
+    TYPE_CHECKING,
 )

 import torch
 import torch.nn as nn
-from torch.utils.hooks import RemovableHandle
 from torch.utils._python_dispatch import TorchDispatchMode
 import operator

+if TYPE_CHECKING:
+    from torch.utils.hooks import RemovableHandle
+

 BYTES_PER_MB = 1024 * 1024.0

@@ -1,11 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import dataclasses
 from collections import defaultdict
-from typing import Dict, List, Set
+from typing import Dict, List, Set, TYPE_CHECKING

-from torch.distributed.checkpoint.metadata import MetadataIndex
 from torch.distributed.checkpoint.planner import SavePlan, WriteItem

+if TYPE_CHECKING:
+    from torch.distributed.checkpoint.metadata import MetadataIndex
+
 __all__ = ["dedup_save_plans"]


@@ -1,11 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import dataclasses
 import logging
-from typing import Dict, List
+from typing import Dict, List, TYPE_CHECKING

-from torch.distributed.checkpoint.metadata import MetadataIndex
 from torch.distributed.checkpoint.planner import SavePlan

+if TYPE_CHECKING:
+    from torch.distributed.checkpoint.metadata import MetadataIndex
+
 __all__ = ["dedup_tensors"]


@@ -1,16 +1,19 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates

 import copy
+from typing import TYPE_CHECKING

 import torch.distributed as dist
 from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
-from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
 from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 from torch.distributed.remote_device import _remote_device

 from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
 from .utils import _element_wise_add, _normalize_device_info

+if TYPE_CHECKING:
+    from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
+

 # TODO: We need to refactor this code.
 def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
@@ -14,7 +14,7 @@ import socket
 from string import Template
 import time
 import uuid
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

 import torch.distributed.elastic.timer as timer
 from torch.distributed.elastic import events
@@ -30,12 +30,14 @@ from torch.distributed.elastic.agent.server.health_check_server import (
     create_healthcheck_server,
     HealthCheckServer,
 )
-from torch.distributed.elastic.events.api import EventMetadataValue
 from torch.distributed.elastic.metrics.api import prof
 from torch.distributed.elastic.multiprocessing import PContext, start_processes, LogsSpecs
 from torch.distributed.elastic.utils import macros
 from torch.distributed.elastic.utils.logging import get_logger

+if TYPE_CHECKING:
+    from torch.distributed.elastic.events.api import EventMetadataValue
+
 logger = get_logger(__name__)

 __all__ = [
@@ -9,10 +9,12 @@
 import logging
 import os
 import time
-from concurrent.futures._base import Future
 from concurrent.futures.thread import ThreadPoolExecutor
 from threading import Event
-from typing import Dict, List, Optional, TextIO
+from typing import Dict, List, Optional, TextIO, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from concurrent.futures._base import Future

 __all__ = ["tail_logfile", "TailLog"]

@@ -31,8 +31,6 @@ from torch.distributed._composable_state import _get_module_state, _State
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_PREFIX,
 )
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
 from torch.distributed.utils import _apply_to_tensors
 from torch.utils._mode_utils import no_dispatch

@@ -46,6 +44,8 @@ from .api import (
 )

 if TYPE_CHECKING:
+    from torch.distributed.device_mesh import DeviceMesh
+    from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
     from ._flat_param import FlatParamHandle

 FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
@@ -15,6 +15,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    TYPE_CHECKING,
     Union,
 )

@@ -58,7 +59,9 @@ from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
 from torch.distributed.utils import _sync_params_and_buffers

 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
-from torch.utils.hooks import RemovableHandle

+if TYPE_CHECKING:
+    from torch.utils.hooks import RemovableHandle
+
 _TORCHDISTX_AVAIL = True
 try:
@@ -17,6 +17,7 @@ from typing import (
     Sequence,
     Set,
     Tuple,
+    TYPE_CHECKING,
     Union,
 )

@@ -24,7 +25,6 @@ import torch
 import torch.distributed as dist
 import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
-from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._state_dict_utils import _gather_state_dict
 from torch.distributed._tensor import DTensor, Replicate
 from torch.distributed.distributed_c10d import _get_pg_default_device
@@ -54,6 +54,9 @@ from torch.distributed.fsdp.api import (
 )
 from torch.utils._pytree import tree_map_only

+if TYPE_CHECKING:
+    from torch.distributed._shard.sharded_tensor import ShardedTensor
+

 logger = logging.getLogger(__name__)

@@ -11,14 +11,13 @@ from collections import defaultdict, deque
 from contextlib import contextmanager
 from dataclasses import dataclass, fields, is_dataclass
 from enum import auto, Enum
-from typing import Any, Callable, List, Optional, Tuple, Type
+from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING

 import torch
 import torch.distributed as dist
 from torch.autograd import Function, Variable
 from torch.distributed.algorithms.join import Join, Joinable, JoinHook
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torch.utils.hooks import RemovableHandle

 RPC_AVAILABLE = False
 if dist.is_available():
@@ -44,6 +43,9 @@ from torch._utils import _get_device_index
 from ..modules import Module
 from .scatter_gather import gather, scatter_kwargs  # noqa: F401

+if TYPE_CHECKING:
+    from torch.utils.hooks import RemovableHandle
+
 __all__ = ["DistributedDataParallel"]

 logger = logging.getLogger(__name__)
@@ -7,7 +7,18 @@ import linecache
 import os
 import types
 from contextlib import contextmanager
-from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    TYPE_CHECKING,
+    Union,
+)
 from weakref import WeakValueDictionary

 import torch
@@ -24,9 +35,11 @@ from ._importlib import (
 from ._mangling import demangle, PackageMangler
 from ._package_unpickler import PackageUnpickler
 from .file_structure_representation import _create_directory_from_file_list, Directory
-from .glob_group import GlobPattern
 from .importer import Importer

+if TYPE_CHECKING:
+    from .glob_group import GlobPattern
+
 __all__ = ["PackageImporter"]


@@ -3,13 +3,15 @@ import operator
 import re
 from collections import deque
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, TYPE_CHECKING

-from torch.autograd import _KinetoEvent
 from torch.autograd.profiler import profile

 from torch.profiler import DeviceType

+if TYPE_CHECKING:
+    from torch.autograd import _KinetoEvent
+

 def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False):
     order = reversed if reverse else lambda x: x