mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE][Easy][17/19] enforce style for empty lines in import segments in torch/[a-c]*/
and torch/[e-n]*/
(#129769)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
2714adce20
commit
f3fce597e9
@ -51,10 +51,8 @@ ISORT_SKIPLIST = re.compile(
|
||||
# torch/_i*/**
|
||||
# torch/_[j-z]*/**
|
||||
# torch/[a-c]*/**
|
||||
"torch/[a-c]*/**",
|
||||
# torch/d*/**
|
||||
# torch/[e-n]*/**
|
||||
"torch/[e-n]*/**",
|
||||
# torch/[o-z]*/**
|
||||
],
|
||||
),
|
||||
|
@ -18,10 +18,11 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch import _C, _ops, Tensor
|
||||
from torch.utils._exposed_in import exposed_in
|
||||
|
||||
from .. import _C, _library, _ops, autograd, library, Tensor
|
||||
from . import utils
|
||||
from . import autograd, utils
|
||||
|
||||
|
||||
device_types_t = Optional[Union[str, Sequence[str]]]
|
||||
@ -363,10 +364,10 @@ class CustomOpDef:
|
||||
self._backend_fns[device_type] = wrapped_fn
|
||||
return fn
|
||||
|
||||
from torch._library.utils import get_device_arg_index, has_tensor_arg
|
||||
|
||||
if device_types is not None and not has_tensor_arg(self._opoverload._schema):
|
||||
device_arg_index = get_device_arg_index(self._opoverload._schema)
|
||||
if device_types is not None and not utils.has_tensor_arg(
|
||||
self._opoverload._schema
|
||||
):
|
||||
device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
|
||||
if device_arg_index is None:
|
||||
raise ValueError(
|
||||
"Functions without tensor inputs are required to have a `device: torch.device` argument"
|
||||
@ -566,7 +567,7 @@ class CustomOpDef:
|
||||
|
||||
"""
|
||||
schema = self._opoverload._schema
|
||||
if not _library.utils.is_functional_schema(schema):
|
||||
if not utils.is_functional_schema(schema):
|
||||
raise RuntimeError(
|
||||
f"Cannot register autograd formula for non-functional operator "
|
||||
f"{self} with schema {schema}. Please create "
|
||||
@ -593,11 +594,11 @@ class CustomOpDef:
|
||||
schema_str,
|
||||
tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
|
||||
)
|
||||
self._opoverload = _library.utils.lookup_op(self._qualname)
|
||||
self._opoverload = utils.lookup_op(self._qualname)
|
||||
|
||||
def fake_impl(*args, **kwargs):
|
||||
if self._abstract_fn is None:
|
||||
if _library.utils.can_generate_trivial_fake_impl(self._opoverload):
|
||||
if utils.can_generate_trivial_fake_impl(self._opoverload):
|
||||
return None
|
||||
raise RuntimeError(
|
||||
f"There was no fake impl registered for {self}. "
|
||||
@ -609,24 +610,24 @@ class CustomOpDef:
|
||||
|
||||
lib._register_fake(self._name, fake_impl, _stacklevel=4)
|
||||
|
||||
autograd_impl = _library.autograd.make_autograd_impl(self._opoverload, self)
|
||||
autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
|
||||
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
|
||||
|
||||
schema = self._opoverload._schema
|
||||
if schema.is_mutable:
|
||||
|
||||
def adinplaceorview_impl(keyset, *args, **kwargs):
|
||||
for arg, val in _library.utils.zip_schema(schema, args, kwargs):
|
||||
for arg, val in utils.zip_schema(schema, args, kwargs):
|
||||
if not arg.alias_info:
|
||||
continue
|
||||
if not arg.alias_info.is_write:
|
||||
continue
|
||||
if isinstance(val, Tensor):
|
||||
autograd.graph.increment_version(val)
|
||||
torch.autograd.graph.increment_version(val)
|
||||
elif isinstance(val, (tuple, list)):
|
||||
for v in val:
|
||||
if isinstance(v, Tensor):
|
||||
autograd.graph.increment_version(v)
|
||||
torch.autograd.graph.increment_version(v)
|
||||
with _C._AutoDispatchBelowADInplaceOrView():
|
||||
return self._opoverload.redispatch(
|
||||
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
|
||||
@ -783,18 +784,20 @@ class CustomOpDef:
|
||||
# decorator.
|
||||
|
||||
|
||||
OPDEF_TO_LIB: Dict[str, "library.Library"] = {}
|
||||
OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
|
||||
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
||||
|
||||
|
||||
def get_library_allowing_overwrite(namespace: str, name: str) -> "library.Library":
|
||||
def get_library_allowing_overwrite(
|
||||
namespace: str, name: str
|
||||
) -> "torch.library.Library":
|
||||
qualname = f"{namespace}::{name}"
|
||||
|
||||
if qualname in OPDEF_TO_LIB:
|
||||
OPDEF_TO_LIB[qualname]._destroy()
|
||||
del OPDEF_TO_LIB[qualname]
|
||||
|
||||
lib = library.Library(namespace, "FRAGMENT")
|
||||
lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901
|
||||
OPDEF_TO_LIB[qualname] = lib
|
||||
return lib
|
||||
|
||||
|
@ -7,14 +7,15 @@ for which gradients should be computed with the ``requires_grad=True`` keyword.
|
||||
As of now, we only support autograd for floating point :class:`Tensor` types (
|
||||
half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union
|
||||
from typing import cast, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch import _vmap_internals
|
||||
from torch.overrides import handle_torch_function, has_torch_function, is_tensor_like
|
||||
from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge
|
||||
from .. import _vmap_internals
|
||||
from ..overrides import handle_torch_function, has_torch_function, is_tensor_like
|
||||
|
||||
from . import forward_ad, functional, graph
|
||||
from .anomaly_mode import detect_anomaly, set_detect_anomaly
|
||||
from .function import Function, NestedIOFunction
|
||||
@ -29,9 +30,9 @@ from .grad_mode import (
|
||||
)
|
||||
from .gradcheck import gradcheck, gradgradcheck
|
||||
from .graph import _engine_run_backward
|
||||
|
||||
from .variable import Variable
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Variable",
|
||||
"Function",
|
||||
@ -575,7 +576,6 @@ from torch._C._autograd import (
|
||||
ProfilerEvent,
|
||||
SavedTensor,
|
||||
)
|
||||
|
||||
from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState
|
||||
|
||||
from . import profiler
|
||||
|
@ -5,7 +5,7 @@ from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
import torch._utils
|
||||
from ..function import Function
|
||||
from torch.autograd.function import Function
|
||||
|
||||
|
||||
class Type(Function):
|
||||
|
@ -4,6 +4,7 @@ import warnings
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = ["detect_anomaly", "set_detect_anomaly"]
|
||||
|
||||
|
||||
|
@ -1,12 +1,13 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import os
|
||||
from collections import namedtuple
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from .grad_mode import _DecoratorContextManager
|
||||
|
||||
|
||||
__all__ = [
|
||||
"UnpackedDualTensor",
|
||||
"enter_dual_level",
|
||||
|
@ -14,6 +14,7 @@ import torch.utils.hooks as hooks
|
||||
from torch._C import _functions
|
||||
from torch._functorch.autograd_function import custom_function_call
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FunctionCtx",
|
||||
"BackwardCFunction",
|
||||
|
@ -3,8 +3,10 @@ from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch._vmap_internals import _vmap
|
||||
|
||||
from . import forward_ad as fwAD
|
||||
|
||||
|
||||
__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
|
||||
|
||||
# Utility functions
|
||||
|
@ -2,13 +2,13 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from torch.utils._contextlib import (
|
||||
_DecoratorContextManager,
|
||||
_NoParamDecoratorContextManager,
|
||||
F,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"no_grad",
|
||||
"enable_grad",
|
||||
|
@ -12,6 +12,7 @@ from torch._vmap_internals import _vmap, vmap
|
||||
from torch.overrides import is_tensor_like
|
||||
from torch.types import _TensorOrTensors
|
||||
|
||||
|
||||
# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
|
||||
# since they have been exposed from before we added `__all__` and we already maintain BC for them
|
||||
# We should eventually deprecate them and remove them from `__all__`
|
||||
|
@ -6,11 +6,9 @@ from typing import Any, Dict, List, Optional
|
||||
from warnings import warn
|
||||
|
||||
import torch
|
||||
|
||||
import torch.cuda
|
||||
from torch._C import _get_privateuse1_backend_name
|
||||
from torch._C._profiler import _ExperimentalConfig
|
||||
|
||||
from torch.autograd import (
|
||||
_disable_profiler,
|
||||
_enable_profiler,
|
||||
@ -36,6 +34,7 @@ from torch.autograd.profiler_util import (
|
||||
)
|
||||
from torch.futures import Future
|
||||
|
||||
|
||||
__all__ = [
|
||||
"profile",
|
||||
"record_function",
|
||||
|
@ -5,7 +5,6 @@ from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
|
||||
from torch.autograd import (
|
||||
_disable_profiler_legacy,
|
||||
_enable_profiler_legacy,
|
||||
@ -22,6 +21,7 @@ from torch.autograd.profiler_util import (
|
||||
MEMORY_EVENT_NAME,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["profile"]
|
||||
|
||||
|
||||
|
@ -2,16 +2,15 @@
|
||||
import bisect
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from collections import defaultdict, namedtuple
|
||||
from operator import attrgetter
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch.autograd import DeviceType
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EventList",
|
||||
"FormattedTimesMixin",
|
||||
|
@ -2,6 +2,7 @@
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
# The idea for this parameter is that we forbid bare assignment
|
||||
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
|
||||
# test suite, where it's very easy to forget to undo the change
|
||||
|
@ -10,6 +10,7 @@ from coremltools.models.neural_network import quantization_utils # type: ignore
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
CT_METADATA_VERSION = "com.github.apple.coremltools.version"
|
||||
CT_METADATA_SOURCE = "com.github.apple.coremltools.source"
|
||||
|
||||
|
@ -5,6 +5,7 @@ from typing import List, Optional
|
||||
import torch
|
||||
from torch.backends._nnapi.serializer import _NnapiSerializer
|
||||
|
||||
|
||||
ANEURALNETWORKS_PREFER_LOW_POWER = 0
|
||||
ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
|
||||
ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
|
||||
|
@ -1,5 +1,6 @@
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_cpu_capability",
|
||||
]
|
||||
|
@ -1,11 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
|
||||
from typing import Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = [
|
||||
"is_built",
|
||||
"cuFFTPlanCacheAttrContextProp",
|
||||
@ -262,6 +262,7 @@ def preferred_blas_library(
|
||||
|
||||
from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend
|
||||
|
||||
|
||||
# Set the __module__ attribute
|
||||
SDPAParams.__module__ = "torch.backends.cuda"
|
||||
SDPAParams.__name__ = "SDPAParams"
|
||||
|
@ -8,6 +8,7 @@ from typing import Optional
|
||||
import torch
|
||||
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
|
||||
|
||||
|
||||
try:
|
||||
from torch._C import _cudnn
|
||||
except ImportError:
|
||||
|
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch.cuda
|
||||
|
||||
|
||||
try:
|
||||
from torch._C import _cudnn
|
||||
except ImportError:
|
||||
|
@ -2,6 +2,7 @@
|
||||
# and nn.TransformerEncoder
|
||||
import torch
|
||||
|
||||
|
||||
_is_fastpath_enabled: bool = True
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
@ -1,10 +1,10 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from functools import lru_cache as _lru_cache
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from ...library import Library as _Library
|
||||
from torch.library import Library as _Library
|
||||
|
||||
|
||||
__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
|
||||
|
||||
@ -43,13 +43,13 @@ _lib: Optional[_Library] = None
|
||||
def _init():
|
||||
r"""Register prims as implementation of var_mean and group_norm."""
|
||||
global _lib
|
||||
if is_built() is False or _lib is not None:
|
||||
return
|
||||
from ..._decomp.decompositions import (
|
||||
native_group_norm_backward as _native_group_norm_backward,
|
||||
)
|
||||
from ..._refs import native_group_norm as _native_group_norm
|
||||
|
||||
_lib = _Library("aten", "IMPL")
|
||||
_lib.impl("native_group_norm", _native_group_norm, "MPS")
|
||||
_lib.impl("native_group_norm_backward", _native_group_norm_backward, "MPS")
|
||||
if _lib is not None or not is_built():
|
||||
return
|
||||
|
||||
from torch._decomp.decompositions import native_group_norm_backward
|
||||
from torch._refs import native_group_norm
|
||||
|
||||
_lib = _Library("aten", "IMPL") # noqa: TOR901
|
||||
_lib.impl("native_group_norm", native_group_norm, "MPS")
|
||||
_lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS")
|
||||
|
@ -4,6 +4,7 @@ from contextlib import contextmanager
|
||||
import torch
|
||||
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
|
||||
|
||||
|
||||
__all__ = ["is_available", "flags", "set_flags"]
|
||||
|
||||
|
||||
|
@ -7,6 +7,7 @@ from typing import Any
|
||||
|
||||
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
|
||||
|
||||
|
||||
try:
|
||||
import opt_einsum as _opt_einsum # type: ignore[import]
|
||||
except ImportError:
|
||||
|
@ -140,6 +140,7 @@ from torch.distributed.elastic.multiprocessing import (
|
||||
Std,
|
||||
)
|
||||
|
||||
|
||||
format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=format_str)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -4,6 +4,7 @@ from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = ["autocast"]
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@ from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = ["GradScaler"]
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
|
||||
|
||||
torch._lazy.ts_backend.init()
|
||||
|
||||
|
||||
|
@ -11,11 +11,8 @@ It is lazily initialized, so you can always import it, and use
|
||||
:ref:`cuda-semantics` has more details about working with CUDA.
|
||||
"""
|
||||
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import warnings
|
||||
@ -24,9 +21,10 @@ from typing import Any, Callable, cast, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
from torch import device as _device
|
||||
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
|
||||
from torch.types import Device
|
||||
from .. import device as _device
|
||||
from .._utils import _dummy_type, _LazySeedTracker, classproperty
|
||||
|
||||
from ._utils import _get_device_index
|
||||
from .graphs import (
|
||||
CUDAGraph,
|
||||
@ -37,6 +35,7 @@ from .graphs import (
|
||||
)
|
||||
from .streams import Event, ExternalStream, Stream
|
||||
|
||||
|
||||
try:
|
||||
from torch._C import _cudart # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
@ -1285,10 +1284,9 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int
|
||||
|
||||
|
||||
from .memory import * # noqa: F403
|
||||
|
||||
|
||||
from .random import * # noqa: F403
|
||||
|
||||
|
||||
################################################################################
|
||||
# Define Storage and Tensor classes
|
||||
################################################################################
|
||||
@ -1537,6 +1535,7 @@ _lazy_call(_register_triton_kernels)
|
||||
|
||||
from . import amp, jiterator, nvtx, profiler, sparse, tunable
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Typed storage and tensors
|
||||
"BFloat16Storage",
|
||||
|
@ -2,6 +2,7 @@ from .autocast_mode import autocast, custom_bwd, custom_fwd
|
||||
from .common import amp_definitely_not_available
|
||||
from .grad_scaler import GradScaler
|
||||
|
||||
|
||||
__all__ = [
|
||||
"amp_definitely_not_available",
|
||||
"autocast",
|
||||
|
@ -5,6 +5,7 @@ from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
|
||||
|
||||
|
||||
|
@ -3,6 +3,7 @@ from importlib.util import find_spec
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = ["amp_definitely_not_available"]
|
||||
|
||||
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
# We need to keep this unused import for BC reasons
|
||||
from torch.amp.grad_scaler import OptState # noqa: F401
|
||||
|
||||
|
||||
__all__ = ["GradScaler"]
|
||||
|
||||
|
||||
|
@ -8,6 +8,7 @@ from torch.nn.parallel.comm import (
|
||||
scatter,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"broadcast",
|
||||
"broadcast_coalesced",
|
||||
|
@ -5,6 +5,7 @@ from typing import Callable, List
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
|
@ -8,15 +8,14 @@ import pickle
|
||||
import sys
|
||||
import warnings
|
||||
from inspect import signature
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
|
||||
from torch._utils import _dummy_type
|
||||
from torch.types import Device
|
||||
from .._utils import _dummy_type
|
||||
|
||||
from . import (
|
||||
_get_amdsmi_device_index,
|
||||
_get_device_index,
|
||||
@ -24,9 +23,9 @@ from . import (
|
||||
_lazy_init,
|
||||
is_initialized,
|
||||
)
|
||||
|
||||
from ._memory_viz import memory as _memory, segments as _segments
|
||||
|
||||
|
||||
__all__ = [
|
||||
"caching_allocator_alloc",
|
||||
"caching_allocator_delete",
|
||||
|
@ -3,6 +3,7 @@ r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profilin
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
try:
|
||||
from torch._C import _nvtx
|
||||
except ImportError:
|
||||
|
@ -3,8 +3,10 @@ import contextlib
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
from . import check_error, cudart
|
||||
|
||||
|
||||
__all__ = ["init", "start", "stop", "profile"]
|
||||
|
||||
DEFAULT_FLAGS = [
|
||||
|
@ -2,9 +2,11 @@
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
import torch
|
||||
from .. import Tensor
|
||||
from torch import Tensor
|
||||
|
||||
from . import _lazy_call, _lazy_init, current_device, device_count
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_rng_state",
|
||||
"get_rng_state_all",
|
||||
|
@ -3,7 +3,7 @@ import ctypes
|
||||
|
||||
import torch
|
||||
from torch._streambase import _EventBase, _StreamBase
|
||||
from .._utils import _dummy_type
|
||||
from torch._utils import _dummy_type
|
||||
|
||||
|
||||
if not hasattr(torch._C, "_CudaStreamBase"):
|
||||
|
@ -25,10 +25,8 @@ from typing import (
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.fx._compatibility import compatibility
|
||||
|
||||
from torch.fx.passes.infra.pass_base import PassResult
|
||||
from torch.fx.passes.infra.pass_manager import PassManager
|
||||
|
||||
from torch.utils._pytree import (
|
||||
FlattenFunc,
|
||||
FromDumpableContextFn,
|
||||
@ -36,6 +34,7 @@ from torch.utils._pytree import (
|
||||
UnflattenFunc,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Import the following modules during type checking to enable code intelligence features,
|
||||
# Do not import unconditionally, as they import sympy and importing sympy is very slow
|
||||
|
@ -4,6 +4,7 @@ from typing import List
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.effects import _get_schema, with_effects
|
||||
|
||||
from .exported_program import ExportedProgram
|
||||
from .graph_signature import (
|
||||
CustomObjArgument,
|
||||
|
@ -13,7 +13,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.fx
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._dynamo.exc import UserError, UserErrorType
|
||||
@ -44,11 +43,9 @@ from torch._export.wrappers import _wrap_submodules
|
||||
from torch._functorch._aot_autograd.traced_function_transforms import (
|
||||
create_functional_call,
|
||||
)
|
||||
|
||||
from torch._functorch._aot_autograd.utils import create_tree_flattened_fn
|
||||
from torch._functorch.aot_autograd import aot_export_module
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
from torch._utils_internal import log_export_usage
|
||||
@ -70,7 +67,6 @@ from torch.utils._pytree import TreeSpec
|
||||
from torch.utils._sympy.value_ranges import ValueRangeError
|
||||
|
||||
from ._safeguard import AutogradStateOpsFailSafeguard
|
||||
|
||||
from .exported_program import (
|
||||
_disable_prexisiting_fake_mode,
|
||||
ExportedProgram,
|
||||
@ -89,6 +85,7 @@ from .graph_signature import (
|
||||
TokenArgument,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -10,7 +10,6 @@ from torch.export.unflatten import _assign_attr, _AttrKind
|
||||
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||
|
||||
from ._remove_effect_tokens_pass import _is_impure_node, _remove_effect_tokens
|
||||
|
||||
from .exported_program import (
|
||||
ExportedProgram,
|
||||
ExportGraphSignature,
|
||||
|
@ -17,12 +17,12 @@ from torch.utils._pytree import (
|
||||
|
||||
from .exported_program import ExportedProgram
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sympy import Symbol
|
||||
|
||||
from torch._guards import Source
|
||||
|
||||
from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint
|
||||
|
||||
__all__ = [
|
||||
"Constraint",
|
||||
|
@ -24,12 +24,11 @@ from typing import (
|
||||
)
|
||||
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||
|
||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Import the following modules during type checking to enable code intelligence features,
|
||||
# such as auto-completion in tools like pylance, even when these modules are not explicitly
|
||||
@ -41,16 +40,12 @@ if TYPE_CHECKING:
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from torch._export.verifier import Verifier
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||
|
||||
from torch.export._tree_utils import is_equivalent, reorder_kwargs
|
||||
from torch.fx._compatibility import compatibility
|
||||
|
||||
from torch.fx._utils import first_call_function_nn_module_stack
|
||||
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
|
||||
|
||||
from torch.fx.passes.infra.pass_base import PassResult
|
||||
from torch.fx.passes.infra.pass_manager import PassManager
|
||||
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
||||
@ -69,6 +64,7 @@ from .graph_signature import ( # noqa: F401
|
||||
TokenArgument,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ExportedProgram",
|
||||
"ModuleCallEntry",
|
||||
@ -331,7 +327,6 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||
from torch._export.passes.lift_constants_pass import ConstantAttrMap
|
||||
from torch._functorch.aot_autograd import aot_export_module
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
from torch.export._trace import (
|
||||
_export_to_aten_ir,
|
||||
_fakify_params_buffers,
|
||||
|
@ -9,8 +9,10 @@ from torch.fx.graph_module import (
|
||||
reduce_package_graph_module,
|
||||
)
|
||||
from torch.package import PackageExporter, sys_importer
|
||||
|
||||
from ._compatibility import compatibility
|
||||
|
||||
|
||||
_use_lazy_graph_module_flag = False
|
||||
_force_skip_lazy_graph_module_flag = False
|
||||
|
||||
|
@ -3,9 +3,9 @@ from collections import namedtuple
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
|
||||
|
||||
import torch.return_types
|
||||
|
||||
from torch.utils._pytree import PyTree, TreeSpec
|
||||
|
||||
|
||||
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
|
||||
FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
|
||||
|
||||
|
@ -3,7 +3,6 @@ import sys
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from torch._logging import LazyString
|
||||
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations.
|
||||
translation_validation = (
|
||||
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
|
||||
@ -76,4 +76,5 @@ use_duck_shape = True
|
||||
|
||||
from torch.utils._config_module import install_config_module
|
||||
|
||||
|
||||
install_config_module(sys.modules[__name__])
|
||||
|
@ -2,6 +2,7 @@
|
||||
import contextlib
|
||||
from typing import List, Optional, Type
|
||||
|
||||
|
||||
__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
|
||||
|
||||
SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
|
||||
|
@ -2,11 +2,11 @@ import re
|
||||
from typing import Any, DefaultDict, Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import sympy as sp
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
square_brackets_pattern = r"\[([^]]+)\]"
|
||||
parentheses_pattern = r"\((.*?)\)"
|
||||
s_pattern = r"s\d+"
|
||||
|
@ -31,12 +31,12 @@ from torch import ( # noqa: F401
|
||||
SymFloat,
|
||||
SymInt,
|
||||
)
|
||||
|
||||
from torch.fx.experimental._sym_dispatch_mode import (
|
||||
handle_sym_dispatch,
|
||||
sym_function_mode,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
|
@ -7,6 +7,7 @@ from torch.fx.graph_module import GraphModule
|
||||
|
||||
from .graph_drawer import FxGraphDrawer
|
||||
|
||||
|
||||
__all__ = ["GraphTransformObserver"]
|
||||
|
||||
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
import operator
|
||||
from typing import Any, Dict, Optional, Set, TYPE_CHECKING
|
||||
|
||||
|
||||
# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
|
||||
if TYPE_CHECKING:
|
||||
import sympy
|
||||
@ -21,6 +22,7 @@ from torch.fx.experimental.proxy_tensor import py_sym_types
|
||||
from torch.fx.experimental.sym_node import SymNode
|
||||
from torch.fx.graph_module import GraphModule
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
|
||||
|
||||
|
@ -1,10 +1,11 @@
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
|
||||
from torch.fx._compatibility import compatibility
|
||||
|
||||
from .matcher_utils import InternalMatch, SubgraphMatcher
|
||||
|
||||
|
||||
__all__ = ["SubgraphMatcherWithNameNodeMap"]
|
||||
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Iterator
|
||||
|
||||
@ -66,9 +65,9 @@ from torch.jit._trace import (
|
||||
TracerWarning,
|
||||
TracingCheckError,
|
||||
)
|
||||
|
||||
from torch.utils import set_module
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Attribute",
|
||||
"CompilationUnit",
|
||||
|
@ -12,9 +12,9 @@ functionalities in `torch.jit`.
|
||||
import torch
|
||||
from torch._jit_internal import Future
|
||||
from torch.jit._builtins import _register_builtin
|
||||
|
||||
from torch.utils import set_module
|
||||
|
||||
|
||||
set_module(Future, "torch.jit")
|
||||
|
||||
|
||||
|
@ -2,9 +2,9 @@
|
||||
import torch
|
||||
from torch._jit_internal import _Await
|
||||
from torch.jit._builtins import _register_builtin
|
||||
|
||||
from torch.utils import set_module
|
||||
|
||||
|
||||
set_module(_Await, "torch.jit")
|
||||
|
||||
|
||||
|
@ -2,14 +2,19 @@
|
||||
import cmath
|
||||
import math
|
||||
import warnings
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.nn.modules.utils import (
|
||||
_list_with_default,
|
||||
_pair,
|
||||
_quadruple,
|
||||
_single,
|
||||
_triple,
|
||||
)
|
||||
|
||||
from ..nn.modules.utils import _list_with_default, _pair, _quadruple, _single, _triple
|
||||
|
||||
_builtin_table: Optional[Dict[int, str]] = None
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
import inspect
|
||||
import warnings
|
||||
@ -10,6 +11,7 @@ from typing_extensions import ParamSpec
|
||||
|
||||
from torch.types import Number
|
||||
|
||||
|
||||
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
|
||||
function_name_set: Set[str] = set()
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import torch
|
||||
|
||||
|
||||
add_stat_value = torch.ops.prim.AddStatValue
|
||||
|
||||
set_logger = torch._C._logging_set_logger
|
||||
|
@ -9,7 +9,6 @@ import warnings
|
||||
from typing import Dict, List, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
import torch._jit_internal as _jit_internal
|
||||
from torch._sources import fake_range
|
||||
from torch.jit._builtins import _find_builtin
|
||||
|
@ -22,7 +22,6 @@ from torch._jit_internal import _get_model_id, _qualified_name
|
||||
from torch._utils_internal import log_torchscript_usage
|
||||
from torch.jit._builtins import _register_builtin
|
||||
from torch.jit._fuser import _graph_for, _script_method_graph_for
|
||||
|
||||
from torch.jit._monkeytype_config import (
|
||||
JitTypeTraceConfig,
|
||||
JitTypeTraceStore,
|
||||
@ -53,6 +52,7 @@ from torch.utils import set_module
|
||||
|
||||
from ._serialization import validate_map_location
|
||||
|
||||
|
||||
type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType
|
||||
|
||||
torch._C.ScriptMethod.graph_for = _script_method_graph_for # type: ignore[attr-defined]
|
||||
|
@ -2,6 +2,7 @@
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
number = Union[int, float]
|
||||
# flake8: noqa
|
||||
|
||||
|
@ -10,7 +10,6 @@ functionalities in `torch.jit`.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
@ -28,15 +27,13 @@ from torch._jit_internal import (
|
||||
get_callable_argument_names,
|
||||
is_scripting,
|
||||
)
|
||||
|
||||
from torch.autograd import function
|
||||
from torch.jit._script import _CachedForward, script, ScriptModule
|
||||
|
||||
from torch.jit._state import _enabled, _python_cu
|
||||
from torch.nn import Module
|
||||
|
||||
from torch.testing._comparison import default_tolerances
|
||||
|
||||
|
||||
_flatten = torch._C._jit_flatten
|
||||
_unflatten = torch._C._jit_unflatten
|
||||
|
||||
|
@ -7,12 +7,10 @@ import inspect
|
||||
import re
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
from textwrap import dedent
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
|
||||
from torch._C import (
|
||||
_GeneratorType,
|
||||
AnyType,
|
||||
@ -36,8 +34,7 @@ from torch._C import (
|
||||
TupleType,
|
||||
UnionType,
|
||||
)
|
||||
from torch._sources import get_source_lines_and_file
|
||||
from .._jit_internal import ( # type: ignore[attr-defined]
|
||||
from torch._jit_internal import ( # type: ignore[attr-defined]
|
||||
_Await,
|
||||
_qualified_name,
|
||||
Any,
|
||||
@ -59,11 +56,14 @@ from .._jit_internal import ( # type: ignore[attr-defined]
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from torch._sources import get_source_lines_and_file
|
||||
|
||||
from ._state import _get_script_class
|
||||
|
||||
|
||||
if torch.distributed.rpc.is_available():
|
||||
from torch._C import RRefType
|
||||
from .._jit_internal import is_rref, RRef
|
||||
from torch._jit_internal import is_rref, RRef
|
||||
|
||||
from torch._ops import OpOverloadPacket
|
||||
|
||||
|
@ -73,6 +73,7 @@ from torch._sources import (
|
||||
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
|
||||
from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace
|
||||
|
||||
|
||||
_IS_ASTUNPARSE_INSTALLED = False
|
||||
try:
|
||||
import astunparse # type: ignore[import]
|
||||
|
@ -2,7 +2,6 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from torch.jit._serialization import validate_map_location
|
||||
|
||||
|
||||
|
@ -5,6 +5,7 @@ import textwrap
|
||||
import torch.jit
|
||||
from torch.jit._builtins import _find_builtin
|
||||
|
||||
|
||||
# this file is for generating documentation using sphinx autodoc
|
||||
# > help(torch.jit.supported_ops) will also give a nice listed of the
|
||||
# supported ops programmatically
|
||||
|
@ -1,6 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from textwrap import dedent
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch.jit
|
||||
|
@ -10,6 +10,7 @@ from torch.masked import _docs
|
||||
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
|
||||
from torch.masked.maskedtensor.creation import as_masked_tensor
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.types import _dtype as DType
|
||||
|
||||
|
@ -5,6 +5,7 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
|
||||
from .core import (
|
||||
_get_data,
|
||||
|
@ -8,7 +8,8 @@ See https://developer.apple.com/documentation/metalperformanceshaders for more d
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from .. import Tensor
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
|
||||
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment]
|
||||
@ -146,6 +147,7 @@ def is_available() -> bool:
|
||||
from . import profiler
|
||||
from .event import Event
|
||||
|
||||
|
||||
__all__ = [
|
||||
"device_count",
|
||||
"get_rng_state",
|
||||
|
@ -3,6 +3,7 @@ import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = ["start", "stop", "profile"]
|
||||
|
||||
|
||||
|
@ -8,13 +8,13 @@ import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch import device as _device, Tensor
|
||||
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
|
||||
from torch.types import Device
|
||||
|
||||
from .. import device as _device, Tensor
|
||||
from .._utils import _dummy_type, _LazySeedTracker, classproperty
|
||||
from ._utils import _get_device_index
|
||||
|
||||
|
||||
_device_t = Union[_device, str, int, None]
|
||||
|
||||
# torch.mtia.Event/Stream is alias of torch.Event/Stream
|
||||
|
@ -18,8 +18,10 @@ import multiprocessing
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
from .reductions import init_reductions
|
||||
|
||||
|
||||
__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import sys
|
||||
|
||||
|
||||
__all__ = ["register_after_fork"]
|
||||
|
||||
if sys.platform == "win32":
|
||||
|
@ -13,6 +13,7 @@ from typing import Optional
|
||||
|
||||
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import * # noqa: F403
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
@ -6,7 +7,7 @@ from torch._C import DispatchKey, DispatchKeySet
|
||||
from torch._prims_common import is_expandable_to
|
||||
from torch.fx.experimental.symbolic_shapes import has_free_symbols
|
||||
from torch.utils.weak import WeakTensorKeyDictionary
|
||||
from typing import * # noqa: F403
|
||||
|
||||
|
||||
_tensor_id_counter = 0
|
||||
_tensor_symint_registry = WeakTensorKeyDictionary()
|
||||
|
@ -2,14 +2,15 @@
|
||||
import functools
|
||||
import math
|
||||
import operator
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.fx.operator_schemas import normalize_function
|
||||
from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
|
||||
|
||||
from .nested_tensor import NestedTensor
|
||||
from typing import * # noqa: F403
|
||||
import torch.nn.functional as F
|
||||
from torch.fx.operator_schemas import normalize_function
|
||||
|
||||
|
||||
__all__: List[Any] = []
|
||||
|
||||
|
@ -17,6 +17,7 @@ from torch.nn.attention import SDPBackend
|
||||
|
||||
from .nested_tensor import NestedTensor
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from torch.nn.parameter import (
|
||||
from torch.nn.parameter import ( # usort: skip
|
||||
Buffer as Buffer,
|
||||
Parameter as Parameter,
|
||||
UninitializedBuffer as UninitializedBuffer,
|
||||
UninitializedParameter as UninitializedParameter,
|
||||
)
|
||||
from torch.nn.modules import * # noqa: F403
|
||||
from torch.nn.modules import * # usort: skip # noqa: F403
|
||||
from torch.nn import (
|
||||
attention as attention,
|
||||
functional as functional,
|
||||
|
@ -24,6 +24,7 @@ from torch.fx.experimental.proxy_tensor import (
|
||||
from torch.nn.attention._utils import _validate_sdpa_input
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BlockMask",
|
||||
"flex_attention",
|
||||
|
@ -11,7 +11,6 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
|
Reference in New Issue
Block a user