Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771)

See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129771
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: c59f3fff52
Commit: 30293319a8
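The diff below is almost entirely blank-line churn inside import blocks. As a rough illustration of the layout the linter settles on, here is a minimal sketch assuming the usual usort/ufmt grouping (future, standard library, then torch imports); the module names and the constant are examples only, not copied from any single file in this diff.

```python
from __future__ import annotations

from typing import Final

import torch
from torch.onnx._internal import jit_utils, registration


# Two blank lines separate the import block from module-level code; a single
# blank line separates the segments above (future / stdlib / torch).
_EXAMPLE_OPSET_VERSION: Final[int] = 18  # illustrative constant, not from the diff
```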
@@ -56,7 +56,6 @@ ISORT_SKIPLIST = re.compile(
# torch/[e-n]*/**
"torch/[e-n]*/**",
# torch/[o-z]*/**
"torch/[o-z]*/**",
],
),
)
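The skiplist above gates which paths the import linter touches; this series shrinks it as each directory range is brought into compliance. A hedged sketch of how such a glob-based skiplist can be compiled and queried (helper names are hypothetical; the real linter adapter may differ):

```python
import re
from fnmatch import translate

SKIP_GLOBS = [
    "torch/[e-n]*/**",  # still skipped in this sketch
    # "torch/[o-z]*/**",  # dropped once those files conform to the style
]
SKIPLIST = re.compile("|".join(translate(g) for g in SKIP_GLOBS))


def should_check(path: str) -> bool:
    """Return True when the import-style check should run on `path`."""
    return SKIPLIST.match(path) is None


print(should_check("torch/optim/sgd.py"))    # True: not covered by a skip glob
print(should_check("torch/fx/__init__.py"))  # False: matches torch/[e-n]*/**
```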
@@ -3,6 +3,28 @@ from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from ._exporter_states import ExportTypes
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
_deprecation,
errors,

@@ -25,20 +47,6 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
utils,
)
from ._exporter_states import ExportTypes
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from ._internal.exporter import ( # usort: skip. needs to be last to avoid circular import
DiagnosticOptions,

@@ -53,12 +61,6 @@ from ._internal.exporter import ( # usort:skip. needs to be last to avoid circu
enable_fake_mode,
)
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
__all__ = [
# Modules
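Two of the import blocks above carry a `# usort: skip` marker, which pins a statement so the sorter does not reorder it (here used to keep one import last and avoid a circular import). A purely illustrative sketch of the directive on unrelated standard-library modules:

```python
# Illustrative only: without the marker, the sorter would hoist `os` above `sys`;
# with it, the marked import keeps its position.
import sys
import os  # usort: skip  (deliberately kept after `sys` for the sake of the example)

print(os.sep, sys.platform)  # the modules import normally; only ordering is pinned
```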
@@ -6,6 +6,7 @@ import warnings
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
_T = TypeVar("_T")
_P = ParamSpec("_P")

@@ -9,6 +9,7 @@ from ._diagnostic import (
from ._rules import rules
from .infra import levels
__all__ = [
"TorchScriptOnnxExportDiagnostic",
"ExportDiagnosticEngine",

@@ -7,12 +7,12 @@ import gzip
from typing import TYPE_CHECKING
import torch
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace
if TYPE_CHECKING:
from collections.abc import Generator

@@ -13,6 +13,7 @@ from typing import Tuple
# flake8: noqa
from torch.onnx._internal.diagnostics import infra
"""
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`

@@ -14,6 +14,7 @@ from ._infra import (
)
from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic
__all__ = [
"Diagnostic",
"DiagnosticContext",

@@ -4,14 +4,10 @@
from __future__ import annotations
import contextlib
import dataclasses
import gzip
import logging
from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
from typing_extensions import Self
from torch.onnx._internal.diagnostics import infra

@@ -7,7 +7,6 @@ import traceback
from typing import Any, Callable, Union
from torch._logging import LazyString
from torch.onnx._internal.diagnostics.infra import sarif

@@ -97,4 +97,5 @@ from torch.onnx._internal.diagnostics.infra.sarif._version_control_details impor
from torch.onnx._internal.diagnostics.infra.sarif._web_request import WebRequest
from torch.onnx._internal.diagnostics.infra.sarif._web_response import WebResponse
# flake8: noqa

@@ -1,5 +1,6 @@
from typing import Final
SARIF_VERSION: Final = "2.1.0"
SARIF_SCHEMA_LINK: Final = "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
# flake8: noqa

@@ -1,7 +1,6 @@
from __future__ import annotations
import functools
import inspect
import traceback
from typing import Any, Callable, Mapping, Sequence
@@ -4,12 +4,10 @@ from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (O
)
import abc
import contextlib
import dataclasses
import logging
import os
import tempfile
import warnings
from collections import defaultdict

@@ -27,11 +25,9 @@ from typing import (
from typing_extensions import Self
import torch
import torch._ops
import torch.export as torch_export
import torch.utils._pytree as pytree
from torch.onnx._internal import io_adapter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import (

@@ -41,6 +37,7 @@ from torch.onnx._internal.fx import (
serialization as fx_serialization,
)
# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).

@@ -48,6 +45,7 @@ if TYPE_CHECKING:
import io
import onnx
import onnxruntime # type: ignore[import]
import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]

@@ -55,7 +53,6 @@ if TYPE_CHECKING:
)
from torch._subclasses import fake_tensor
from torch.onnx._internal.fx import diagnostics
_DEFAULT_OPSET_VERSION: Final[int] = 18
@@ -2,23 +2,20 @@
from __future__ import annotations
import abc
import contextlib
import dataclasses
import difflib
import io
import logging
import sys
from typing import Any, Callable, TYPE_CHECKING
import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
if TYPE_CHECKING:
from torch._subclasses import fake_tensor

@@ -1,5 +1,6 @@
from .unsupported_nodes import UnsupportedFxNodesAnalysis
__all__ = [
"UnsupportedFxNodesAnalysis",
]
@@ -15,7 +15,6 @@ from __future__ import annotations
import abc
import contextlib
from typing import Callable, Sequence
from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-found]

@@ -26,6 +25,7 @@ from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-
import torch
from torch._decomp import decompositions
_NEW_OP_NAMESPACE: str = "onnx_export"
"""The namespace for the custom operator."""

@@ -8,7 +8,6 @@ from typing import Callable
import torch
import torch._ops
import torch.fx
from torch.onnx._internal.fx import registration

@@ -2,9 +2,7 @@
from __future__ import annotations
import dataclasses
import functools
from typing import Any, TYPE_CHECKING
import onnxscript # type: ignore[import]

@@ -17,6 +15,7 @@ from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter
from torch.onnx._internal.fx import registration, type_utils as fx_type_utils
if TYPE_CHECKING:
import logging

@@ -16,7 +16,6 @@ from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
import torch
import torch.fx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal.fx import (
_pass,
diagnostics,

@@ -2,16 +2,15 @@
from __future__ import annotations
import functools
from typing import Any, Callable, Mapping, Sequence
import torch
import torch.fx
import torch.onnx
import torch.onnx._internal.fx.passes as passes
from torch.onnx._internal import exporter, io_adapter
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go though Python dispatcher, so
@@ -11,13 +11,13 @@ from typing import Any, Callable, Sequence, TYPE_CHECKING
import torch
import torch._ops
import torch.fx
from torch.onnx._internal.fx import (
diagnostics,
registration,
type_utils as fx_type_utils,
)
if TYPE_CHECKING:
import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]

@@ -13,7 +13,6 @@ import torch
import torch.fx
from torch.fx.experimental import symbolic_shapes
from torch.onnx import _constants, _type_utils as jit_type_utils
from torch.onnx._internal.fx import (
diagnostics,
fx_onnx_interpreter,

@@ -5,6 +5,7 @@ from .readability import RestoreParameterAndBufferNames
from .type_promotion import InsertTypePromotion
from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder
__all__ = [
"Decompose",
"InsertTypePromotion",

@@ -6,9 +6,7 @@ These functions should NOT be directly invoked outside of `passes` package.
from __future__ import annotations
import collections
import re
from typing import Callable
import torch.fx
@@ -2,7 +2,6 @@
from __future__ import annotations
import contextlib
from typing import Callable, Mapping, TYPE_CHECKING
import torch

@@ -12,6 +11,7 @@ from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
if TYPE_CHECKING:
import torch.fx
from torch._subclasses import fake_tensor

@@ -2,7 +2,6 @@
from __future__ import annotations
import contextlib
from typing import Callable, TYPE_CHECKING
import torch

@@ -14,6 +13,7 @@ from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
from torch._subclasses import fake_tensor

@@ -2,19 +2,17 @@
from __future__ import annotations
import abc
import collections
import copy
import operator
from typing import Any, Dict, Final, Generator, Iterator, Sequence, Tuple
import torch
import torch.fx
from torch.onnx._internal.fx import _pass, diagnostics
from torch.utils import _pytree as pytree
_FX_TRACER_NN_MODULE_META_TYPE = Tuple[str, type]
"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer."""
_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict

@@ -4,7 +4,6 @@ from __future__ import annotations
from typing import Sequence
import torch
from torch.onnx._internal.fx import _pass, diagnostics
@@ -3,20 +3,16 @@
from __future__ import annotations
import abc
import dataclasses
import inspect
import logging
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
import torch
import torch._ops
import torch.fx
import torch.fx.traceback as fx_traceback
from torch import _prims_common, _refs
from torch._prims_common import (
ELEMENTWISE_TYPE_PROMOTION_KIND,
wrappers as _prims_common_wrappers,

@@ -24,15 +20,16 @@ from torch._prims_common import (
from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
from torch._refs.nn import functional as _functional_refs
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
from torch.utils import _python_dispatch, _pytree
if TYPE_CHECKING:
from types import ModuleType
from torch._subclasses import fake_tensor
logger = logging.getLogger(__name__)
# TODO(bowbao): move to type utils.

@@ -4,9 +4,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch.onnx._internal.fx import _pass
if TYPE_CHECKING:
import torch.fx

@@ -5,6 +5,7 @@ from typing import List, TYPE_CHECKING, Union
import torch
if TYPE_CHECKING:
import io

@@ -16,7 +17,6 @@ def has_safetensors_and_transformers():
# safetensors is not an exporter requirement, but needed for some huggingface models
import safetensors # type: ignore[import] # noqa: F401
import transformers # type: ignore[import] # noqa: F401
from safetensors import torch as safetensors_torch # noqa: F401
return True
@@ -13,6 +13,7 @@ import torch.fx
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.diagnostics import infra
if TYPE_CHECKING:
import torch.onnx
from torch.export.exported_program import ExportedProgram

@@ -15,11 +15,13 @@ from typing import (
)
import numpy
import onnx
import torch
from torch._subclasses import fake_tensor
if TYPE_CHECKING:
import onnx.defs.OpSchema.AttrType # type: ignore[import] # noqa: TCH004

@@ -13,9 +13,9 @@ from typing import (
import torch
import torch.export as torch_export
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
import inspect

@@ -1,10 +1,11 @@
# mypy: allow-untyped-defs
"""Utilities for manipulating the torch.Graph object and the torchscript."""
from __future__ import annotations
# TODO(justinchuby): Move more of the symbolic helper functions here and expose
# them to the user.
from __future__ import annotations
import dataclasses
import re
import typing

@@ -4,7 +4,6 @@ import dataclasses
import importlib
import logging
import os
from typing import (
Any,
Callable,

@@ -33,6 +32,7 @@ from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree
if TYPE_CHECKING:
import onnx
@@ -929,7 +929,7 @@ class OrtBackend:
try:
from onnxscript import optimizer # type: ignore[import]
from onnxscript.rewriter import ( # type: ignore[import]
onnxruntime as ort_rewriter, # type: ignore[import]
onnxruntime as ort_rewriter,
)
onnx_model = optimizer.optimize(onnx_model)

@@ -1112,7 +1112,6 @@ class OrtBackend:
the ``compile`` method is invoked directly."""
if self._options.use_aot_autograd:
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
return aot_autograd(

@@ -10,6 +10,7 @@ import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import errors
if typing.TYPE_CHECKING:
# Hack to help mypy to recognize torch._C.Value
from torch import _C # noqa: F401

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
from torch.onnx import _constants
from torch.onnx._internal import diagnostics
if TYPE_CHECKING:
from torch import _C

@@ -19,6 +19,7 @@ from torch.onnx import _constants, _type_utils, errors, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils
if typing.TYPE_CHECKING:
from torch.types import Number

@@ -23,6 +23,7 @@ from torch.onnx import (
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

@@ -21,6 +21,7 @@ from torch.onnx import (
)
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

@@ -25,6 +25,7 @@ from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
__all__ = [
"hardswish",
"tril",
@@ -33,6 +33,7 @@ from torch import _C
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)

@@ -36,6 +36,7 @@ from torch.nn.functional import (
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)

@@ -25,6 +25,7 @@ from torch import _C
from torch.onnx import _type_utils, errors, symbolic_helper
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

@@ -27,6 +27,7 @@ from torch import _C
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

@@ -26,6 +26,7 @@ Size
from typing import List
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

@@ -24,11 +24,11 @@ New operators:
import functools
import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

@@ -39,6 +39,7 @@ from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
block_listed_operators = (

@@ -27,6 +27,7 @@ from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_h
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
if TYPE_CHECKING:
from torch.types import Number

@@ -30,6 +30,7 @@ from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils
from torch.types import Number
_ORT_PROVIDERS = ("CPUExecutionProvider",)
_NumericType = Union[Number, torch.Tensor, np.ndarray]
@@ -23,6 +23,7 @@ from torch.optim.rprop import Rprop
from torch.optim.sgd import SGD
from torch.optim.sparse_adam import SparseAdam
Adafactor.__module__ = "torch.optim"

@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_disable_dynamo_if_unsupported,
_get_scalar_dtype,

@@ -12,6 +13,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adafactor", "adafactor"]

@@ -20,6 +20,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adadelta", "adadelta"]

@@ -4,6 +4,7 @@ from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_default_to_fused_or_foreach,
_differentiable_doc,

@@ -17,6 +18,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adagrad", "adagrad"]

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,

@@ -24,6 +25,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adam", "adam"]

@@ -21,6 +21,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adamax", "adamax"]

@@ -5,6 +5,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,

@@ -24,6 +25,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["AdamW", "adamw"]

@@ -21,6 +21,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["ASGD", "asgd"]

@@ -3,8 +3,10 @@ from typing import Optional, Union
import torch
from torch import Tensor
from .optimizer import Optimizer, ParamsT
__all__ = ["LBFGS"]
@@ -26,6 +26,7 @@ from torch import inf, Tensor
from .optimizer import Optimizer
__all__ = [
"LambdaLR",
"MultiplicativeLR",

@@ -5,6 +5,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,

@@ -22,6 +23,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["NAdam", "nadam"]

@@ -35,6 +35,7 @@ from torch.utils._foreach_utils import (
)
from torch.utils.hooks import RemovableHandle
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]

@@ -22,6 +22,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["RAdam", "radam"]

@@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,

@@ -20,6 +21,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["RMSprop", "rmsprop"]

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,

@@ -20,6 +21,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Rprop", "rprop"]

@@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_default_to_fused_or_foreach,
_differentiable_doc,

@@ -16,6 +17,7 @@ from .optimizer import (
Optimizer,
)
__all__ = ["SGD", "sgd"]

@@ -3,9 +3,11 @@ from typing import List, Tuple, Union
import torch
from torch import Tensor
from . import _functional as F
from .optimizer import _maximize_doc, Optimizer, ParamsT
__all__ = ["SparseAdam"]
@@ -11,8 +11,10 @@ from torch import Tensor
from torch.nn import Module
from torch.optim.lr_scheduler import _format_param, LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
from .optimizer import Optimizer
__all__ = [
"AveragedModel",
"update_bn",

@@ -25,6 +27,7 @@ __all__ = [
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]

@@ -6,6 +6,7 @@ from typing import cast
import torch
from torch.types import Storage
__serialization_id_record_name__ = ".data/serialization_id"

@@ -2,6 +2,7 @@
import _warnings
import os.path
# note: implementations
# copied from cpython's import code

@@ -4,6 +4,7 @@ See mangling.md for details.
"""
import re
_mangle_index = 0

@@ -1,6 +1,7 @@
from typing import Dict, List
from ..package_exporter import PackagingError
from torch.package.package_exporter import PackagingError
__all__ = ["find_first_use_of_broken_modules"]

@@ -2,6 +2,7 @@
import sys
from typing import Any, Callable, Iterable, List, Tuple
__all__ = ["trace_dependencies"]

@@ -3,6 +3,7 @@ from typing import Dict, List
from .glob_group import GlobGroup, GlobPattern
__all__ = ["Directory"]

@@ -2,6 +2,7 @@
import re
from typing import Iterable, Union
GlobPattern = Union[str, Iterable[str]]

@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import importlib
from abc import ABC, abstractmethod
from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined]
from pickle import ( # type: ignore[attr-defined]
_getattribute,
_Pickler,
whichmodule as _pickle_whichmodule,

@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
from ._mangling import demangle, get_mangle_prefix, is_mangled
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]
@@ -39,6 +39,7 @@ from .find_file_dependencies import find_files_source_depends_on
from .glob_group import GlobGroup, GlobPattern
from .importer import Importer, OrderedImporter, sys_importer
__all__ = [
"PackagingErrorReason",
"EmptyMatchError",

@@ -38,6 +38,7 @@ from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .importer import Importer
if TYPE_CHECKING:
from .glob_group import GlobPattern

@@ -25,6 +25,7 @@ from .profiler import (
tensorboard_trace_handler,
)
__all__ = [
"profile",
"schedule",

@@ -32,6 +32,7 @@ from torch._C._profiler import (
from torch._utils import _element_size
from torch.profiler import _utils
KeyAndID = Tuple["Key", int]
TensorAndID = Tuple["TensorKey", int]

@@ -7,9 +7,9 @@ from dataclasses import dataclass
from typing import Dict, List, TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
if TYPE_CHECKING:
from torch.autograd import _KinetoEvent

@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
try:
from torch._C import _itt
except ImportError:

@@ -1,16 +1,14 @@
# mypy: allow-untyped-defs
from .quantize import * # noqa: F403
from .observer import * # noqa: F403
from .qconfig import * # noqa: F403
from .fake_quantize import * # noqa: F403
from .fuse_modules import fuse_modules
from .stubs import * # noqa: F403
from .quant_type import * # noqa: F403
from .quantize_jit import * # noqa: F403
# from .quantize_fx import *
from .quantization_mappings import * # noqa: F403
from .fuser_method_mappings import * # noqa: F403
from .observer import * # noqa: F403
from .qconfig import * # noqa: F403
from .quant_type import * # noqa: F403
from .quantization_mappings import * # noqa: F403
from .quantize import * # noqa: F403
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403
def default_eval_fn(model, calib_data):

@@ -15,6 +15,7 @@ from torch.ao.quantization.fx.pattern_utils import (
QuantizeHandler,
)
# QuantizeHandler.__module__ = _NAMESPACE
_register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils"
get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils"

@@ -23,6 +23,7 @@ from torch.ao.quantization.fx.quantize_handler import (
StandaloneModuleQuantizeHandler,
)
QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"

@@ -15,6 +15,7 @@ from .semi_structured import (
to_sparse_semi_structured,
)
if TYPE_CHECKING:
from torch.types import _dtype as DType
@@ -3,6 +3,7 @@ import contextlib
import torch
__all__ = [
"fallback_dispatcher",
"semi_sparse_values",

@@ -8,8 +8,10 @@ from typing import Optional, Tuple
import torch
from torch.utils._triton import has_triton
from ._triton_ops_meta import get_meta
TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(
os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2)
)

@@ -20,6 +20,7 @@ from torch.sparse._semi_structured_ops import (
semi_sparse_view,
)
__all__ = [
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",

@@ -1,4 +1,5 @@
from torch._C import FileCheck as FileCheck
from . import _utils
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor

@@ -9,6 +9,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch
_INTEGRAL_TYPES = [
torch.uint8,
torch.int8,

@@ -8,7 +8,6 @@ from typing import Any, Dict
import torch
import torch.nn as nn
from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor

@@ -3,6 +3,7 @@ import logging
import os
import sys
# NOTE: [dynamo_test_failures.py]
#
# We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures`

@@ -34,9 +34,9 @@ from torch.testing._internal.common_utils import (
TrackedInputIter,
)
from torch.testing._internal.opinfo import utils
from torchgen.utils import dataclass_repr
# Reasonable testing sizes for dimensions
L = 20
M = 10
@@ -11,6 +11,7 @@ from torch.testing._internal.opinfo.definitions import (
special,
)
# Operator database
op_db: List[OpInfo] = [
*fft.op_db,

@@ -7,7 +7,6 @@ from typing import List
import numpy as np
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater
from torch.testing._internal.common_device_type import precisionOverride

@@ -31,6 +30,7 @@ from torch.testing._internal.opinfo.refs import (
PythonRefInfo,
)
has_scipy_fft = False
if TEST_SCIPY:
try:

@@ -11,7 +11,6 @@ import numpy as np
from numpy import inf
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
_get_magma_version,

@@ -2,7 +2,6 @@
import unittest
from functools import partial
from itertools import product
from typing import Callable, List, Tuple

@@ -18,6 +17,7 @@ from torch.testing._internal.opinfo.core import (
SampleInput,
)
if TEST_SCIPY:
import scipy.signal

@@ -7,6 +7,7 @@ from torch.testing._internal.opinfo.core import (
UnaryUfuncInfo,
)
# NOTE [Python References]
# Python References emulate existing PyTorch operations, but can ultimately
# be expressed in terms of "primitive" operations from torch._prims.