[BE][Easy][17/19] enforce style for empty lines in import segments in torch/[a-c]*/ and torch/[e-n]*/ (#129769)

See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs, ignoring whitespace-only changes, via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```
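
For context, here is a minimal sketch of the layout being enforced, based on a reading of the auto-generated changes below (module and function names are illustrative, not taken from this PR): one empty line between import segments, none inside a segment, and a blank line separating the import block from the first definition.

```python
# Illustrative only -- not a file from this PR. Spacing follows the
# usort/isort convention the linter applies in this stack.
import math  # standard-library segment

import torch  # framework segment, separated by exactly one empty line
from torch import Tensor

__all__ = ["scale_by_pi"]


def scale_by_pi(x: Tensor) -> Tensor:
    # Trivial body; the function exists only to exercise the imports above.
    return x * math.pi
```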

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769
Approved by: https://github.com/ezyang
Author: Xuehai Pan
Date: 2024-08-04 10:41:34 +08:00
Committed by: PyTorch MergeBot
Parent: 2714adce20
Commit: f3fce597e9
86 changed files with 152 additions and 111 deletions

View File

@@ -51,10 +51,8 @@ ISORT_SKIPLIST = re.compile(
# torch/_i*/**
# torch/_[j-z]*/**
# torch/[a-c]*/**
"torch/[a-c]*/**",
# torch/d*/**
# torch/[e-n]*/**
"torch/[e-n]*/**",
# torch/[o-z]*/**
],
),
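
The hunk above removes `torch/[a-c]*/**` and `torch/[e-n]*/**` from the linter's skip list, so those directories are now checked. A rough sketch of how such glob patterns exempt paths (assumed for illustration; the real linter compiles them into the `ISORT_SKIPLIST` regex shown above):

```python
# Assumed helper, not the linter's actual code.
from fnmatch import fnmatch

# Globs that remain in the skip list after this commit.
STILL_SKIPPED = ["torch/d*/**", "torch/[o-z]*/**"]


def is_skipped(path: str) -> bool:
    # A path is exempt from the import-style check if any glob matches it.
    return any(fnmatch(path, pattern) for pattern in STILL_SKIPPED)


assert is_skipped("torch/distributed/run.py")  # still exempt
assert not is_skipped("torch/autograd/grad_mode.py")  # linted as of this PR
```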

View File

@@ -18,10 +18,11 @@ from typing import (
Union,
)
import torch
from torch import _C, _ops, Tensor
from torch.utils._exposed_in import exposed_in
from .. import _C, _library, _ops, autograd, library, Tensor
from . import utils
from . import autograd, utils
device_types_t = Optional[Union[str, Sequence[str]]]
@@ -363,10 +364,10 @@ class CustomOpDef:
self._backend_fns[device_type] = wrapped_fn
return fn
from torch._library.utils import get_device_arg_index, has_tensor_arg
if device_types is not None and not has_tensor_arg(self._opoverload._schema):
device_arg_index = get_device_arg_index(self._opoverload._schema)
if device_types is not None and not utils.has_tensor_arg(
self._opoverload._schema
):
device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
if device_arg_index is None:
raise ValueError(
"Functions without tensor inputs are required to have a `device: torch.device` argument"
@@ -566,7 +567,7 @@ class CustomOpDef:
"""
schema = self._opoverload._schema
if not _library.utils.is_functional_schema(schema):
if not utils.is_functional_schema(schema):
raise RuntimeError(
f"Cannot register autograd formula for non-functional operator "
f"{self} with schema {schema}. Please create "
@@ -593,11 +594,11 @@ class CustomOpDef:
schema_str,
tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
)
self._opoverload = _library.utils.lookup_op(self._qualname)
self._opoverload = utils.lookup_op(self._qualname)
def fake_impl(*args, **kwargs):
if self._abstract_fn is None:
if _library.utils.can_generate_trivial_fake_impl(self._opoverload):
if utils.can_generate_trivial_fake_impl(self._opoverload):
return None
raise RuntimeError(
f"There was no fake impl registered for {self}. "
@@ -609,24 +610,24 @@ class CustomOpDef:
lib._register_fake(self._name, fake_impl, _stacklevel=4)
autograd_impl = _library.autograd.make_autograd_impl(self._opoverload, self)
autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
schema = self._opoverload._schema
if schema.is_mutable:
def adinplaceorview_impl(keyset, *args, **kwargs):
for arg, val in _library.utils.zip_schema(schema, args, kwargs):
for arg, val in utils.zip_schema(schema, args, kwargs):
if not arg.alias_info:
continue
if not arg.alias_info.is_write:
continue
if isinstance(val, Tensor):
autograd.graph.increment_version(val)
torch.autograd.graph.increment_version(val)
elif isinstance(val, (tuple, list)):
for v in val:
if isinstance(v, Tensor):
autograd.graph.increment_version(v)
torch.autograd.graph.increment_version(v)
with _C._AutoDispatchBelowADInplaceOrView():
return self._opoverload.redispatch(
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
@@ -783,18 +784,20 @@ class CustomOpDef:
# decorator.
OPDEF_TO_LIB: Dict[str, "library.Library"] = {}
OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
def get_library_allowing_overwrite(namespace: str, name: str) -> "library.Library":
def get_library_allowing_overwrite(
namespace: str, name: str
) -> "torch.library.Library":
qualname = f"{namespace}::{name}"
if qualname in OPDEF_TO_LIB:
OPDEF_TO_LIB[qualname]._destroy()
del OPDEF_TO_LIB[qualname]
lib = library.Library(namespace, "FRAGMENT")
lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901
OPDEF_TO_LIB[qualname] = lib
return lib
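
The hunks above switch this file to the fully qualified `torch.library.Library` (with a `# noqa: TOR901` suppressing a TorchFix lint on direct `Library` construction). For reference, a minimal sketch of how a `"FRAGMENT"` library is used; the namespace and operator are made up, assuming a recent torch build:

```python
import torch

# "FRAGMENT" adds operators to a namespace without claiming exclusive
# ownership of it, which is what get_library_allowing_overwrite relies on.
lib = torch.library.Library("mylib", "FRAGMENT")  # noqa: TOR901
lib.define("add_one(Tensor x) -> Tensor")


def add_one(x: torch.Tensor) -> torch.Tensor:
    # Plain eager implementation, registered for CPU tensors below.
    return x + 1


lib.impl("add_one", add_one, "CPU")
print(torch.ops.mylib.add_one(torch.zeros(3)))  # tensor([1., 1., 1.])
```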

View File

@@ -7,14 +7,15 @@ for which gradients should be computed with the ``requires_grad=True`` keyword.
As of now, we only support autograd for floating point :class:`Tensor` types (
half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
"""
import warnings
from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union
from typing import cast, List, Optional, Sequence, Tuple, Union
import torch
from torch import _vmap_internals
from torch.overrides import handle_torch_function, has_torch_function, is_tensor_like
from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge
from .. import _vmap_internals
from ..overrides import handle_torch_function, has_torch_function, is_tensor_like
from . import forward_ad, functional, graph
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from .function import Function, NestedIOFunction
@@ -29,9 +30,9 @@ from .grad_mode import (
)
from .gradcheck import gradcheck, gradgradcheck
from .graph import _engine_run_backward
from .variable import Variable
__all__ = [
"Variable",
"Function",
@@ -575,7 +576,6 @@ from torch._C._autograd import (
ProfilerEvent,
SavedTensor,
)
from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState
from . import profiler

View File

@@ -5,7 +5,7 @@ from typing_extensions import deprecated
import torch
import torch._utils
from ..function import Function
from torch.autograd.function import Function
class Type(Function):

View File

@@ -4,6 +4,7 @@ import warnings
import torch
__all__ = ["detect_anomaly", "set_detect_anomaly"]

View File

@@ -1,12 +1,13 @@
# mypy: allow-untyped-defs
import os
from collections import namedtuple
from typing import Any
import torch
from .grad_mode import _DecoratorContextManager
__all__ = [
"UnpackedDualTensor",
"enter_dual_level",

View File

@@ -14,6 +14,7 @@ import torch.utils.hooks as hooks
from torch._C import _functions
from torch._functorch.autograd_function import custom_function_call
__all__ = [
"FunctionCtx",
"BackwardCFunction",

View File

@@ -3,8 +3,10 @@ from typing import List, Tuple
import torch
from torch._vmap_internals import _vmap
from . import forward_ad as fwAD
__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
# Utility functions

View File

@@ -2,13 +2,13 @@
from typing import Any
import torch
from torch.utils._contextlib import (
_DecoratorContextManager,
_NoParamDecoratorContextManager,
F,
)
__all__ = [
"no_grad",
"enable_grad",

View File

@@ -12,6 +12,7 @@ from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors
# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
# since they have been exposed from before we added `__all__` and we already maintain BC for them
# We should eventually deprecate them and remove them from `__all__`

View File

@@ -6,11 +6,9 @@ from typing import Any, Dict, List, Optional
from warnings import warn
import torch
import torch.cuda
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import _ExperimentalConfig
from torch.autograd import (
_disable_profiler,
_enable_profiler,
@@ -36,6 +34,7 @@ from torch.autograd.profiler_util import (
)
from torch.futures import Future
__all__ = [
"profile",
"record_function",

View File

@@ -5,7 +5,6 @@ from typing_extensions import deprecated
import torch
import torch.cuda
from torch.autograd import (
_disable_profiler_legacy,
_enable_profiler_legacy,
@@ -22,6 +21,7 @@ from torch.autograd.profiler_util import (
MEMORY_EVENT_NAME,
)
__all__ = ["profile"]

View File

@@ -2,16 +2,15 @@
import bisect
import itertools
import math
from collections import defaultdict, namedtuple
from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple
from typing_extensions import deprecated
import torch
from torch.autograd import DeviceType
__all__ = [
"EventList",
"FormattedTimesMixin",

View File

@@ -2,6 +2,7 @@
import types
from contextlib import contextmanager
# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
# test suite, where it's very easy to forget to undo the change

View File

@@ -10,6 +10,7 @@ from coremltools.models.neural_network import quantization_utils # type: ignore
import torch
CT_METADATA_VERSION = "com.github.apple.coremltools.version"
CT_METADATA_SOURCE = "com.github.apple.coremltools.source"

View File

@@ -5,6 +5,7 @@ from typing import List, Optional
import torch
from torch.backends._nnapi.serializer import _NnapiSerializer
ANEURALNETWORKS_PREFER_LOW_POWER = 0
ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2

View File

@@ -1,5 +1,6 @@
import torch
__all__ = [
"get_cpu_capability",
]

View File

@@ -1,11 +1,11 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Union
from typing_extensions import deprecated
import torch
__all__ = [
"is_built",
"cuFFTPlanCacheAttrContextProp",
@@ -262,6 +262,7 @@ def preferred_blas_library(
from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend
# Set the __module__ attribute
SDPAParams.__module__ = "torch.backends.cuda"
SDPAParams.__name__ = "SDPAParams"

View File

@@ -8,6 +8,7 @@ from typing import Optional
import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
try:
from torch._C import _cudnn
except ImportError:

View File

@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import torch.cuda
try:
from torch._C import _cudnn
except ImportError:

View File

@@ -2,6 +2,7 @@
# and nn.TransformerEncoder
import torch
_is_fastpath_enabled: bool = True

View File

@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch

View File

@@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
from functools import lru_cache as _lru_cache
from typing import Optional
from typing import Optional, TYPE_CHECKING
import torch
from ...library import Library as _Library
from torch.library import Library as _Library
__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
@@ -43,13 +43,13 @@ _lib: Optional[_Library] = None
def _init():
r"""Register prims as implementation of var_mean and group_norm."""
global _lib
if is_built() is False or _lib is not None:
return
from ..._decomp.decompositions import (
native_group_norm_backward as _native_group_norm_backward,
)
from ..._refs import native_group_norm as _native_group_norm
_lib = _Library("aten", "IMPL")
_lib.impl("native_group_norm", _native_group_norm, "MPS")
_lib.impl("native_group_norm_backward", _native_group_norm_backward, "MPS")
if _lib is not None or not is_built():
return
from torch._decomp.decompositions import native_group_norm_backward
from torch._refs import native_group_norm
_lib = _Library("aten", "IMPL") # noqa: TOR901
_lib.impl("native_group_norm", native_group_norm, "MPS")
_lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS")

View File

@@ -4,6 +4,7 @@ from contextlib import contextmanager
import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
__all__ = ["is_available", "flags", "set_flags"]

View File

@@ -7,6 +7,7 @@ from typing import Any
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
try:
import opt_einsum as _opt_einsum # type: ignore[import]
except ImportError:

View File

@@ -140,6 +140,7 @@ from torch.distributed.elastic.multiprocessing import (
Std,
)
format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=format_str)
logger = logging.getLogger(__name__)

View File

@@ -4,6 +4,7 @@ from typing_extensions import deprecated
import torch
__all__ = ["autocast"]

View File

@@ -2,6 +2,7 @@ from typing_extensions import deprecated
import torch
__all__ = ["GradScaler"]

View File

@@ -13,6 +13,7 @@ import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
torch._lazy.ts_backend.init()

View File

@@ -11,11 +11,8 @@ It is lazily initialized, so you can always import it, and use
:ref:`cuda-semantics` has more details about working with CUDA.
"""
import contextlib
import importlib
import os
import sys
import threading
import traceback
import warnings
@@ -24,9 +21,10 @@ from typing import Any, Callable, cast, List, Optional, Tuple, Union
import torch
import torch._C
from torch import device as _device
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
from torch.types import Device
from .. import device as _device
from .._utils import _dummy_type, _LazySeedTracker, classproperty
from ._utils import _get_device_index
from .graphs import (
CUDAGraph,
@@ -37,6 +35,7 @@ from .graphs import (
)
from .streams import Event, ExternalStream, Stream
try:
from torch._C import _cudart # type: ignore[attr-defined]
except ImportError:
@@ -1285,10 +1284,9 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int
from .memory import * # noqa: F403
from .random import * # noqa: F403
################################################################################
# Define Storage and Tensor classes
################################################################################
@@ -1537,6 +1535,7 @@ _lazy_call(_register_triton_kernels)
from . import amp, jiterator, nvtx, profiler, sparse, tunable
__all__ = [
# Typed storage and tensors
"BFloat16Storage",

View File

@@ -2,6 +2,7 @@ from .autocast_mode import autocast, custom_bwd, custom_fwd
from .common import amp_definitely_not_available
from .grad_scaler import GradScaler
__all__ = [
"amp_definitely_not_available",
"autocast",

View File

@@ -5,6 +5,7 @@ from typing_extensions import deprecated
import torch
__all__ = ["autocast", "custom_fwd", "custom_bwd"]

View File

@@ -3,6 +3,7 @@ from importlib.util import find_spec
import torch
__all__ = ["amp_definitely_not_available"]

View File

@@ -5,6 +5,7 @@ import torch
# We need to keep this unused import for BC reasons
from torch.amp.grad_scaler import OptState # noqa: F401
__all__ = ["GradScaler"]

View File

@@ -8,6 +8,7 @@ from torch.nn.parallel.comm import (
scatter,
)
__all__ = [
"broadcast",
"broadcast_coalesced",

View File

@@ -5,6 +5,7 @@ from typing import Callable, List
import torch
from torch import Tensor
__all__: List[str] = []

View File

@@ -8,15 +8,14 @@ import pickle
import sys
import warnings
from inspect import signature
from typing import Any, Dict, Optional, Tuple, Union
from typing_extensions import deprecated
import torch
from torch import _C
from torch._utils import _dummy_type
from torch.types import Device
from .._utils import _dummy_type
from . import (
_get_amdsmi_device_index,
_get_device_index,
@@ -24,9 +23,9 @@ from . import (
_lazy_init,
is_initialized,
)
from ._memory_viz import memory as _memory, segments as _segments
__all__ = [
"caching_allocator_alloc",
"caching_allocator_delete",

View File

@@ -3,6 +3,7 @@ r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profilin
from contextlib import contextmanager
try:
from torch._C import _nvtx
except ImportError:

View File

@@ -3,8 +3,10 @@ import contextlib
import tempfile
import torch
from . import check_error, cudart
__all__ = ["init", "start", "stop", "profile"]
DEFAULT_FLAGS = [

View File

@@ -2,9 +2,11 @@
from typing import Iterable, List, Union
import torch
from .. import Tensor
from torch import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count
__all__ = [
"get_rng_state",
"get_rng_state_all",

View File

@@ -3,7 +3,7 @@ import ctypes
import torch
from torch._streambase import _EventBase, _StreamBase
from .._utils import _dummy_type
from torch._utils import _dummy_type
if not hasattr(torch._C, "_CudaStreamBase"):

View File

@@ -25,10 +25,8 @@ from typing import (
import torch
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.utils._pytree import (
FlattenFunc,
FromDumpableContextFn,
@@ -36,6 +34,7 @@ from torch.utils._pytree import (
UnflattenFunc,
)
if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow

View File

@@ -4,6 +4,7 @@ from typing import List
import torch
from torch._higher_order_ops.effects import _get_schema, with_effects
from .exported_program import ExportedProgram
from .graph_signature import (
CustomObjArgument,

View File

@@ -13,7 +13,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch._dynamo
import torch.fx
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.exc import UserError, UserErrorType
@@ -44,11 +43,9 @@ from torch._export.wrappers import _wrap_submodules
from torch._functorch._aot_autograd.traced_function_transforms import (
create_functional_call,
)
from torch._functorch._aot_autograd.utils import create_tree_flattened_fn
from torch._functorch.aot_autograd import aot_export_module
from torch._guards import detect_fake_mode
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._utils_internal import log_export_usage
@@ -70,7 +67,6 @@ from torch.utils._pytree import TreeSpec
from torch.utils._sympy.value_ranges import ValueRangeError
from ._safeguard import AutogradStateOpsFailSafeguard
from .exported_program import (
_disable_prexisiting_fake_mode,
ExportedProgram,
@@ -89,6 +85,7 @@ from .graph_signature import (
TokenArgument,
)
log = logging.getLogger(__name__)

View File

@@ -10,7 +10,6 @@ from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from ._remove_effect_tokens_pass import _is_impure_node, _remove_effect_tokens
from .exported_program import (
ExportedProgram,
ExportGraphSignature,

View File

@@ -17,12 +17,12 @@ from torch.utils._pytree import (
from .exported_program import ExportedProgram
if TYPE_CHECKING:
from sympy import Symbol
from torch._guards import Source
from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint
from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint
__all__ = [
"Constraint",

View File

@@ -24,12 +24,11 @@ from typing import (
)
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.immutable_collections import immutable_dict, immutable_list
if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
@@ -41,16 +40,12 @@ if TYPE_CHECKING:
import torch
import torch.utils._pytree as pytree
from torch._export.verifier import Verifier
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
@@ -69,6 +64,7 @@ from .graph_signature import ( # noqa: F401
TokenArgument,
)
__all__ = [
"ExportedProgram",
"ModuleCallEntry",
@@ -331,7 +327,6 @@ def _decompose_and_get_gm_with_new_signature_constants(
from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._functorch.aot_autograd import aot_export_module
from torch._guards import detect_fake_mode
from torch.export._trace import (
_export_to_aten_ir,
_fakify_params_buffers,

View File

@@ -9,8 +9,10 @@ from torch.fx.graph_module import (
reduce_package_graph_module,
)
from torch.package import PackageExporter, sys_importer
from ._compatibility import compatibility
_use_lazy_graph_module_flag = False
_force_skip_lazy_graph_module_flag = False

View File

@@ -3,9 +3,9 @@ from collections import namedtuple
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
import torch.return_types
from torch.utils._pytree import PyTree, TreeSpec
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]

View File

@@ -3,7 +3,6 @@ import sys
from typing import Dict, Optional
import torch
from torch._logging import LazyString

View File

@@ -1,8 +1,8 @@
import os
import sys
from typing import Optional
# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations.
translation_validation = (
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
@@ -76,4 +76,5 @@ use_duck_shape = True
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])

View File

@@ -2,6 +2,7 @@
import contextlib
from typing import List, Optional, Type
__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None

View File

@@ -2,11 +2,11 @@ import re
from typing import Any, DefaultDict, Dict, List, Tuple, Union
import numpy as np
import sympy as sp
import torch
square_brackets_pattern = r"\[([^]]+)\]"
parentheses_pattern = r"\((.*?)\)"
s_pattern = r"s\d+"

View File

@@ -31,12 +31,12 @@ from torch import ( # noqa: F401
SymFloat,
SymInt,
)
from torch.fx.experimental._sym_dispatch_mode import (
handle_sym_dispatch,
sym_function_mode,
)
if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import ShapeEnv

View File

@@ -7,6 +7,7 @@ from torch.fx.graph_module import GraphModule
from .graph_drawer import FxGraphDrawer
__all__ = ["GraphTransformObserver"]

View File

@@ -4,6 +4,7 @@ import logging
import operator
from typing import Any, Dict, Optional, Set, TYPE_CHECKING
# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
if TYPE_CHECKING:
import sympy
@@ -21,6 +22,7 @@ from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.experimental.sym_node import SymNode
from torch.fx.graph_module import GraphModule
log = logging.getLogger(__name__)
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")

View File

@@ -1,10 +1,11 @@
from typing import Dict, List, Tuple
from torch.fx import Graph, GraphModule, Node
from torch.fx._compatibility import compatibility
from .matcher_utils import InternalMatch, SubgraphMatcher
__all__ = ["SubgraphMatcherWithNameNodeMap"]

View File

@@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import warnings
from contextlib import contextmanager
from typing import Any, Iterator
@@ -66,9 +65,9 @@ from torch.jit._trace import (
TracerWarning,
TracingCheckError,
)
from torch.utils import set_module
__all__ = [
"Attribute",
"CompilationUnit",

View File

@@ -12,9 +12,9 @@ functionalities in `torch.jit`.
import torch
from torch._jit_internal import Future
from torch.jit._builtins import _register_builtin
from torch.utils import set_module
set_module(Future, "torch.jit")

View File

@@ -2,9 +2,9 @@
import torch
from torch._jit_internal import _Await
from torch.jit._builtins import _register_builtin
from torch.utils import set_module
set_module(_Await, "torch.jit")

View File

@@ -2,14 +2,19 @@
import cmath
import math
import warnings
from collections import OrderedDict
from typing import Dict, Optional
import torch
import torch.backends.cudnn as cudnn
from torch.nn.modules.utils import (
_list_with_default,
_pair,
_quadruple,
_single,
_triple,
)
from ..nn.modules.utils import _list_with_default, _pair, _quadruple, _single, _triple
_builtin_table: Optional[Dict[int, str]] = None

View File

@@ -2,6 +2,7 @@
import torch
from torch import Tensor
aten = torch.ops.aten
import inspect
import warnings
@@ -10,6 +11,7 @@ from typing_extensions import ParamSpec
from torch.types import Number
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
function_name_set: Set[str] = set()

View File

@@ -1,5 +1,6 @@
import torch
add_stat_value = torch.ops.prim.AddStatValue
set_logger = torch._C._logging_set_logger

View File

@@ -9,7 +9,6 @@ import warnings
from typing import Dict, List, Set, Type
import torch
import torch._jit_internal as _jit_internal
from torch._sources import fake_range
from torch.jit._builtins import _find_builtin

View File

@@ -22,7 +22,6 @@ from torch._jit_internal import _get_model_id, _qualified_name
from torch._utils_internal import log_torchscript_usage
from torch.jit._builtins import _register_builtin
from torch.jit._fuser import _graph_for, _script_method_graph_for
from torch.jit._monkeytype_config import (
JitTypeTraceConfig,
JitTypeTraceStore,
@@ -53,6 +52,7 @@ from torch.utils import set_module
from ._serialization import validate_map_location
type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType
torch._C.ScriptMethod.graph_for = _script_method_graph_for # type: ignore[attr-defined]

View File

@@ -2,6 +2,7 @@
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
number = Union[int, float]
# flake8: noqa

View File

@@ -10,7 +10,6 @@ functionalities in `torch.jit`.
"""
import contextlib
import copy
import functools
import inspect
@@ -28,15 +27,13 @@ from torch._jit_internal import (
get_callable_argument_names,
is_scripting,
)
from torch.autograd import function
from torch.jit._script import _CachedForward, script, ScriptModule
from torch.jit._state import _enabled, _python_cu
from torch.nn import Module
from torch.testing._comparison import default_tolerances
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten

View File

@@ -7,12 +7,10 @@ import inspect
import re
import typing
import warnings
from textwrap import dedent
from typing import Type
import torch
from torch._C import (
_GeneratorType,
AnyType,
@@ -36,8 +34,7 @@ from torch._C import (
TupleType,
UnionType,
)
from torch._sources import get_source_lines_and_file
from .._jit_internal import ( # type: ignore[attr-defined]
from torch._jit_internal import ( # type: ignore[attr-defined]
_Await,
_qualified_name,
Any,
@@ -59,11 +56,14 @@ from .._jit_internal import ( # type: ignore[attr-defined]
Tuple,
Union,
)
from torch._sources import get_source_lines_and_file
from ._state import _get_script_class
if torch.distributed.rpc.is_available():
from torch._C import RRefType
from .._jit_internal import is_rref, RRef
from torch._jit_internal import is_rref, RRef
from torch._ops import OpOverloadPacket

View File

@@ -73,6 +73,7 @@ from torch._sources import (
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace
_IS_ASTUNPARSE_INSTALLED = False
try:
import astunparse # type: ignore[import]

View File

@@ -2,7 +2,6 @@
import os
import torch
from torch.jit._serialization import validate_map_location

View File

@@ -5,6 +5,7 @@ import textwrap
import torch.jit
from torch.jit._builtins import _find_builtin
# this file is for generating documentation using sphinx autodoc
# > help(torch.jit.supported_ops) will also give a nice listed of the
# supported ops programmatically

View File

@@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from textwrap import dedent
from typing import Any, Dict
import torch.jit

View File

@@ -10,6 +10,7 @@ from torch.masked import _docs
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
from torch.masked.maskedtensor.creation import as_masked_tensor
if TYPE_CHECKING:
from torch.types import _dtype as DType

View File

@@ -5,6 +5,7 @@ from functools import partial
from typing import Any, Callable, Dict, TYPE_CHECKING
import torch
from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
from .core import (
_get_data,

View File

@@ -8,7 +8,8 @@ See https://developer.apple.com/documentation/metalperformanceshaders for more d
from typing import Union
import torch
from .. import Tensor
from torch import Tensor
_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment]
@@ -146,6 +147,7 @@ def is_available() -> bool:
from . import profiler
from .event import Event
__all__ = [
"device_count",
"get_rng_state",

View File

@@ -3,6 +3,7 @@ import contextlib
import torch
__all__ = ["start", "stop", "profile"]

View File

@@ -8,13 +8,13 @@ import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import device as _device, Tensor
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
from torch.types import Device
from .. import device as _device, Tensor
from .._utils import _dummy_type, _LazySeedTracker, classproperty
from ._utils import _get_device_index
_device_t = Union[_device, str, int, None]
# torch.mtia.Event/Stream is alias of torch.Event/Stream

View File

@@ -18,8 +18,10 @@ import multiprocessing
import sys
import torch
from .reductions import init_reductions
__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]

View File

@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import sys
__all__ = ["register_after_fork"]
if sys.platform == "win32":

View File

@@ -13,6 +13,7 @@ from typing import Optional
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
log = logging.getLogger(__name__)

View File

@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
from typing import * # noqa: F403
from typing import Tuple
import torch
@@ -6,7 +7,7 @@ from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.utils.weak import WeakTensorKeyDictionary
from typing import * # noqa: F403
_tensor_id_counter = 0
_tensor_symint_registry = WeakTensorKeyDictionary()

View File

@@ -2,14 +2,15 @@
import functools
import math
import operator
from typing import * # noqa: F403
import torch
import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function
from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
from .nested_tensor import NestedTensor
from typing import * # noqa: F403
import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function
__all__: List[Any] = []

View File

@@ -17,6 +17,7 @@ from torch.nn.attention import SDPBackend
from .nested_tensor import NestedTensor
log = logging.getLogger(__name__)

View File

@@ -1,11 +1,11 @@
# mypy: allow-untyped-defs
from torch.nn.parameter import (
from torch.nn.parameter import ( # usort: skip
Buffer as Buffer,
Parameter as Parameter,
UninitializedBuffer as UninitializedBuffer,
UninitializedParameter as UninitializedParameter,
)
from torch.nn.modules import * # noqa: F403
from torch.nn.modules import * # usort: skip # noqa: F403
from torch.nn import (
attention as attention,
functional as functional,

View File

@@ -24,6 +24,7 @@ from torch.fx.experimental.proxy_tensor import (
from torch.nn.attention._utils import _validate_sdpa_input
from torch.utils._pytree import tree_map_only
__all__ = [
"BlockMask",
"flex_attention",

View File

@@ -11,7 +11,6 @@ from typing import (
TypeVar,
Union,
)
from typing_extensions import Self
import torch