diff --git a/tools/linter/adapters/ufmt_linter.py b/tools/linter/adapters/ufmt_linter.py
index a58c656b3d73..27ad2e001c36 100644
--- a/tools/linter/adapters/ufmt_linter.py
+++ b/tools/linter/adapters/ufmt_linter.py
@@ -56,7 +56,6 @@ ISORT_SKIPLIST = re.compile(
         # torch/[e-n]*/**
         "torch/[e-n]*/**",
         # torch/[o-z]*/**
-        "torch/[o-z]*/**",
     ],
 ),
 )
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 212dd5968399..6d7d5b0ee8b2 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -3,7 +3,29 @@
 from torch import _C
 from torch._C import _onnx as _C_onnx
 from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
-from . import (  # usort:skip. Keep the order instead of sorting lexicographically
+from ._exporter_states import ExportTypes
+from ._internal.onnxruntime import (
+    is_onnxrt_backend_supported,
+    OrtBackend as _OrtBackend,
+    OrtBackendOptions as _OrtBackendOptions,
+    OrtExecutionProvider as _OrtExecutionProvider,
+)
+from ._type_utils import JitScalarType
+from .errors import CheckerError  # Backwards compatibility
+from .utils import (
+    _optimize_graph,
+    _run_symbolic_function,
+    _run_symbolic_method,
+    export,
+    export_to_pretty_string,
+    is_in_onnx_export,
+    register_custom_op_symbolic,
+    select_model_mode_for_export,
+    unregister_custom_op_symbolic,
+)
+
+
+from . import (  # usort: skip. Keep the order instead of sorting lexicographically
     _deprecation,
     errors,
     symbolic_caffe2,
@@ -25,22 +47,8 @@ from . import (  # usort:skip. Keep the order instead of sorting lexicographical
     utils,
 )
 
-from ._exporter_states import ExportTypes
-from ._type_utils import JitScalarType
-from .errors import CheckerError  # Backwards compatibility
-from .utils import (
-    _optimize_graph,
-    _run_symbolic_function,
-    _run_symbolic_method,
-    export,
-    export_to_pretty_string,
-    is_in_onnx_export,
-    register_custom_op_symbolic,
-    select_model_mode_for_export,
-    unregister_custom_op_symbolic,
-)
-from ._internal.exporter import (  # usort:skip. needs to be last to avoid circular import
+from ._internal.exporter import (  # usort: skip. needs to be last to avoid circular import
     DiagnosticOptions,
     ExportOptions,
     ONNXProgram,
@@ -53,12 +61,6 @@ from ._internal.exporter import (  # usort:skip. needs to be last to avoid circu
     enable_fake_mode,
 )
 
-from ._internal.onnxruntime import (
-    is_onnxrt_backend_supported,
-    OrtBackend as _OrtBackend,
-    OrtBackendOptions as _OrtBackendOptions,
-    OrtExecutionProvider as _OrtExecutionProvider,
-)
 
 __all__ = [
     # Modules
diff --git a/torch/onnx/_deprecation.py b/torch/onnx/_deprecation.py
index c2aa295c4d4d..24fe4ccc54fc 100644
--- a/torch/onnx/_deprecation.py
+++ b/torch/onnx/_deprecation.py
@@ -6,6 +6,7 @@
 import warnings
 from typing import Callable, TypeVar
 from typing_extensions import ParamSpec
 
+
 _T = TypeVar("_T")
 _P = ParamSpec("_P")
diff --git a/torch/onnx/_internal/diagnostics/__init__.py b/torch/onnx/_internal/diagnostics/__init__.py
index cae5b247d5cd..a3eab565a20e 100644
--- a/torch/onnx/_internal/diagnostics/__init__.py
+++ b/torch/onnx/_internal/diagnostics/__init__.py
@@ -9,6 +9,7 @@ from ._diagnostic import (
 from ._rules import rules
 from .infra import levels
 
+
 __all__ = [
     "TorchScriptOnnxExportDiagnostic",
     "ExportDiagnosticEngine",
diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py
index 5fbd9719bed3..6100033d6bae 100644
--- a/torch/onnx/_internal/diagnostics/_diagnostic.py
+++ b/torch/onnx/_internal/diagnostics/_diagnostic.py
@@ -7,12 +7,12 @@
 import gzip
 from typing import TYPE_CHECKING
 
 import torch
-
 from torch.onnx._internal.diagnostics import infra
 from torch.onnx._internal.diagnostics.infra import formatter, sarif
 from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
 from torch.utils import cpp_backtrace
 
+
 if TYPE_CHECKING:
     from collections.abc import Generator
diff --git a/torch/onnx/_internal/diagnostics/_rules.py b/torch/onnx/_internal/diagnostics/_rules.py
index 3b2ca727d0d1..1d353d2e88b3 100644
--- a/torch/onnx/_internal/diagnostics/_rules.py
+++ b/torch/onnx/_internal/diagnostics/_rules.py
@@ -13,6 +13,7 @@ from typing import Tuple
 
 # flake8: noqa
 from torch.onnx._internal.diagnostics import infra
 
+
 """
 GENERATED CODE - DO NOT EDIT DIRECTLY
 The purpose of generating a class for each rule is to override the `format_message`
diff --git a/torch/onnx/_internal/diagnostics/infra/__init__.py b/torch/onnx/_internal/diagnostics/infra/__init__.py
index 6eb6bb444dff..ddcd4891643a 100644
--- a/torch/onnx/_internal/diagnostics/infra/__init__.py
+++ b/torch/onnx/_internal/diagnostics/infra/__init__.py
@@ -14,6 +14,7 @@ from ._infra import (
 )
 from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic
 
+
 __all__ = [
     "Diagnostic",
     "DiagnosticContext",
diff --git a/torch/onnx/_internal/diagnostics/infra/context.py b/torch/onnx/_internal/diagnostics/infra/context.py
index 33c1f7efb19b..36ae4207b427 100644
--- a/torch/onnx/_internal/diagnostics/infra/context.py
+++ b/torch/onnx/_internal/diagnostics/infra/context.py
@@ -4,14 +4,10 @@
 from __future__ import annotations
 
 import contextlib
-
 import dataclasses
 import gzip
-
 import logging
-
 from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
-
 from typing_extensions import Self
 
 from torch.onnx._internal.diagnostics import infra
diff --git a/torch/onnx/_internal/diagnostics/infra/formatter.py b/torch/onnx/_internal/diagnostics/infra/formatter.py
index 72155c40843a..5abb698dafce 100644
--- a/torch/onnx/_internal/diagnostics/infra/formatter.py
+++ b/torch/onnx/_internal/diagnostics/infra/formatter.py
@@ -7,7 +7,6 @@
 import traceback
 from typing import Any, Callable, Union
 
 from torch._logging import LazyString
-
 from torch.onnx._internal.diagnostics.infra import sarif
 
diff --git a/torch/onnx/_internal/diagnostics/infra/sarif/__init__.py b/torch/onnx/_internal/diagnostics/infra/sarif/__init__.py
index 34fd40e5b938..a01b2abeef9b 100644
--- a/torch/onnx/_internal/diagnostics/infra/sarif/__init__.py
+++ b/torch/onnx/_internal/diagnostics/infra/sarif/__init__.py
@@ -97,4 +97,5 @@ from torch.onnx._internal.diagnostics.infra.sarif._version_control_details impor
 from torch.onnx._internal.diagnostics.infra.sarif._web_request import WebRequest
 from torch.onnx._internal.diagnostics.infra.sarif._web_response import WebResponse
 
+
 # flake8: noqa
diff --git a/torch/onnx/_internal/diagnostics/infra/sarif/version.py b/torch/onnx/_internal/diagnostics/infra/sarif/version.py
index 2beddcb3f042..a9b2c9d8fa07 100644
--- a/torch/onnx/_internal/diagnostics/infra/sarif/version.py
+++ b/torch/onnx/_internal/diagnostics/infra/sarif/version.py
@@ -1,5 +1,6 @@
 from typing import Final
 
+
 SARIF_VERSION: Final = "2.1.0"
 SARIF_SCHEMA_LINK: Final = "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
 # flake8: noqa
diff --git a/torch/onnx/_internal/diagnostics/infra/utils.py b/torch/onnx/_internal/diagnostics/infra/utils.py
index f3aa38ee1009..a5d49c38968f 100644
--- a/torch/onnx/_internal/diagnostics/infra/utils.py
+++ b/torch/onnx/_internal/diagnostics/infra/utils.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import functools
-
 import inspect
 import traceback
 from typing import Any, Callable, Mapping, Sequence
diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py
index 7c7203c80851..135319ab76f0 100644
--- a/torch/onnx/_internal/exporter.py
+++ b/torch/onnx/_internal/exporter.py
@@ -4,12 +4,10 @@ from __future__ import (  # for onnx.ModelProto (ONNXProgram) and onnxruntime (O
 )
 
 import abc
-
 import contextlib
 import dataclasses
 import logging
 import os
-
 import tempfile
 import warnings
 from collections import defaultdict
@@ -27,11 +25,9 @@ from typing import (
 from typing_extensions import Self
 
 import torch
-
 import torch._ops
 import torch.export as torch_export
 import torch.utils._pytree as pytree
-
 from torch.onnx._internal import io_adapter
 from torch.onnx._internal.diagnostics import infra
 from torch.onnx._internal.fx import (
@@ -41,6 +37,7 @@ from torch.onnx._internal.fx import (
     serialization as fx_serialization,
 )
 
+
 # We can only import onnx from this module in a type-checking context to ensure that
 # 'import torch.onnx' continues to work without having 'onnx' installed. We fully
 # 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
@@ -48,6 +45,7 @@ if TYPE_CHECKING:
     import io
 
     import onnx
+
     import onnxruntime  # type: ignore[import]
     import onnxscript  # type: ignore[import]
     from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
@@ -55,7 +53,6 @@ if TYPE_CHECKING:
     )
 
     from torch._subclasses import fake_tensor
-
     from torch.onnx._internal.fx import diagnostics
 
 _DEFAULT_OPSET_VERSION: Final[int] = 18
diff --git a/torch/onnx/_internal/fx/_pass.py b/torch/onnx/_internal/fx/_pass.py
index ff527afdb556..870b4549d30c 100644
--- a/torch/onnx/_internal/fx/_pass.py
+++ b/torch/onnx/_internal/fx/_pass.py
@@ -2,23 +2,20 @@
 from __future__ import annotations
 
 import abc
-
 import contextlib
 import dataclasses
 import difflib
-
 import io
 import logging
 import sys
-
 from typing import Any, Callable, TYPE_CHECKING
 
 import torch
 import torch.fx
 from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
-
 from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
 
+
 if TYPE_CHECKING:
     from torch._subclasses import fake_tensor
diff --git a/torch/onnx/_internal/fx/analysis/__init__.py b/torch/onnx/_internal/fx/analysis/__init__.py
index 406a8d4ac9e2..4440cf66939b 100644
--- a/torch/onnx/_internal/fx/analysis/__init__.py
+++ b/torch/onnx/_internal/fx/analysis/__init__.py
@@ -1,5 +1,6 @@
 from .unsupported_nodes import UnsupportedFxNodesAnalysis
 
+
 __all__ = [
     "UnsupportedFxNodesAnalysis",
 ]
diff --git a/torch/onnx/_internal/fx/decomposition_skip.py b/torch/onnx/_internal/fx/decomposition_skip.py
index c3230c1d64a0..aae663906cd5 100644
--- a/torch/onnx/_internal/fx/decomposition_skip.py
+++ b/torch/onnx/_internal/fx/decomposition_skip.py
@@ -15,7 +15,6 @@ from __future__ import annotations
 
 import abc
 import contextlib
-
 from typing import Callable, Sequence
 
 from onnxscript.function_libs.torch_lib.ops import (  # type: ignore[import-not-found]
@@ -26,6 +25,7 @@ from onnxscript.function_libs.torch_lib.ops import (  # type: ignore[import-not-
 
 import torch
 from torch._decomp import decompositions
 
+
 _NEW_OP_NAMESPACE: str = "onnx_export"
 """The namespace for the custom operator."""
diff --git a/torch/onnx/_internal/fx/decomposition_table.py b/torch/onnx/_internal/fx/decomposition_table.py
index 43e384c3a2dc..6ca128da11e4 100644
--- a/torch/onnx/_internal/fx/decomposition_table.py
+++ b/torch/onnx/_internal/fx/decomposition_table.py
@@ -8,7 +8,6 @@ from typing import Callable
 
 import torch
 import torch._ops
 import torch.fx
-
 from torch.onnx._internal.fx import registration
 
diff --git a/torch/onnx/_internal/fx/diagnostics.py b/torch/onnx/_internal/fx/diagnostics.py
index fd3668fddaee..8617afa5f440 100644
--- a/torch/onnx/_internal/fx/diagnostics.py
+++ b/torch/onnx/_internal/fx/diagnostics.py
@@ -2,9 +2,7 @@
 from __future__ import annotations
 
 import dataclasses
-
 import functools
-
 from typing import Any, TYPE_CHECKING
 
 import onnxscript  # type: ignore[import]
@@ -17,6 +15,7 @@ from torch.onnx._internal.diagnostics import infra
 from torch.onnx._internal.diagnostics.infra import decorator, formatter
 from torch.onnx._internal.fx import registration, type_utils as fx_type_utils
 
+
 if TYPE_CHECKING:
     import logging
diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py
index 1fd4e6fb4043..8247ce338466 100644
--- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py
+++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py
@@ -16,7 +16,6 @@ from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
 import torch
 import torch.fx
 from torch.onnx import _type_utils as jit_type_utils
-
 from torch.onnx._internal.fx import (
     _pass,
     diagnostics,
diff --git a/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py b/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py
index 3bce25bb3ea3..ec05a718fb3c 100644
--- a/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py
+++ b/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py
@@ -2,16 +2,15 @@ from __future__ import annotations
 
 import functools
-
 from typing import Any, Callable, Mapping, Sequence
 
 import torch
 import torch.fx
 import torch.onnx
-
 import torch.onnx._internal.fx.passes as passes
 from torch.onnx._internal import exporter, io_adapter
 
+
 # Functions directly wrapped to produce torch.fx.Proxy so that symbolic
 # data can flow through those functions. Python functions (e.g., `torch.arange`)
 # not defined by pybind11 in C++ do not go though Python dispatcher, so
diff --git a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py
index 23a6ef7c7416..23dca4227b55 100644
--- a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py
+++ b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py
@@ -11,13 +11,13 @@ from typing import Any, Callable, Sequence, TYPE_CHECKING
 import torch
 import torch._ops
 import torch.fx
-
 from torch.onnx._internal.fx import (
     diagnostics,
     registration,
     type_utils as fx_type_utils,
 )
 
+
 if TYPE_CHECKING:
     import onnxscript  # type: ignore[import]
     from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
diff --git a/torch/onnx/_internal/fx/op_validation.py b/torch/onnx/_internal/fx/op_validation.py
index 10cb8e923584..15b2aa12e71f 100644
--- a/torch/onnx/_internal/fx/op_validation.py
+++ b/torch/onnx/_internal/fx/op_validation.py
@@ -13,7 +13,6 @@ import torch
 import torch.fx
 from torch.fx.experimental import symbolic_shapes
 from torch.onnx import _constants, _type_utils as jit_type_utils
-
 from torch.onnx._internal.fx import (
     diagnostics,
     fx_onnx_interpreter,
diff --git a/torch/onnx/_internal/fx/passes/__init__.py b/torch/onnx/_internal/fx/passes/__init__.py
index 7f9cdfd16cfa..aa04e6beb5f1 100644
--- a/torch/onnx/_internal/fx/passes/__init__.py
+++ b/torch/onnx/_internal/fx/passes/__init__.py
@@ -5,6 +5,7 @@ from .readability import RestoreParameterAndBufferNames
 from .type_promotion import InsertTypePromotion
 from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder
 
+
 __all__ = [
     "Decompose",
     "InsertTypePromotion",
diff --git a/torch/onnx/_internal/fx/passes/_utils.py b/torch/onnx/_internal/fx/passes/_utils.py
index 60326c17315b..fc35e24b1bee 100644
--- a/torch/onnx/_internal/fx/passes/_utils.py
+++ b/torch/onnx/_internal/fx/passes/_utils.py
@@ -6,9 +6,7 @@ These functions should NOT be directly invoked outside of `passes` package.
 from __future__ import annotations
 
 import collections
-
 import re
-
 from typing import Callable
 
 import torch.fx
diff --git a/torch/onnx/_internal/fx/passes/decomp.py b/torch/onnx/_internal/fx/passes/decomp.py
index df8ce6a07c21..e0869e78d100 100644
--- a/torch/onnx/_internal/fx/passes/decomp.py
+++ b/torch/onnx/_internal/fx/passes/decomp.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import contextlib
-
 from typing import Callable, Mapping, TYPE_CHECKING
 
 import torch
@@ -12,6 +11,7 @@ from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal.fx import _pass, diagnostics
 from torch.onnx._internal.fx.passes import _utils
 
+
 if TYPE_CHECKING:
     import torch.fx
     from torch._subclasses import fake_tensor
diff --git a/torch/onnx/_internal/fx/passes/functionalization.py b/torch/onnx/_internal/fx/passes/functionalization.py
index d4eddd19654a..e3e5b36c0a15 100644
--- a/torch/onnx/_internal/fx/passes/functionalization.py
+++ b/torch/onnx/_internal/fx/passes/functionalization.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import contextlib
-
 from typing import Callable, TYPE_CHECKING
 
 import torch
@@ -14,6 +13,7 @@ from torch.onnx._internal.fx import _pass, diagnostics
 from torch.onnx._internal.fx.passes import _utils
 from torch.utils import _pytree as pytree
 
+
 if TYPE_CHECKING:
     from torch._subclasses import fake_tensor
diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py
index db74d52dda47..bcd3e956b425 100644
--- a/torch/onnx/_internal/fx/passes/modularization.py
+++ b/torch/onnx/_internal/fx/passes/modularization.py
@@ -2,19 +2,17 @@ from __future__ import annotations
 
 import abc
-
 import collections
 import copy
 import operator
-
 from typing import Any, Dict, Final, Generator, Iterator, Sequence, Tuple
 
 import torch
 import torch.fx
-
 from torch.onnx._internal.fx import _pass, diagnostics
 from torch.utils import _pytree as pytree
 
+
 _FX_TRACER_NN_MODULE_META_TYPE = Tuple[str, type]
 """Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer."""
 _FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict
diff --git a/torch/onnx/_internal/fx/passes/readability.py b/torch/onnx/_internal/fx/passes/readability.py
index d91d40a88e5c..50221f47f64f 100644
--- a/torch/onnx/_internal/fx/passes/readability.py
+++ b/torch/onnx/_internal/fx/passes/readability.py
@@ -4,7 +4,6 @@ from __future__ import annotations
 
 from typing import Sequence
 
 import torch
-
 from torch.onnx._internal.fx import _pass, diagnostics
 
diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py
index ca5e4ff4381f..ee5a55cacf81 100644
--- a/torch/onnx/_internal/fx/passes/type_promotion.py
+++ b/torch/onnx/_internal/fx/passes/type_promotion.py
@@ -3,20 +3,16 @@
 from __future__ import annotations
 
 import abc
-
 import dataclasses
 import inspect
 import logging
-
 from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
 
 import torch
 import torch._ops
 import torch.fx
 import torch.fx.traceback as fx_traceback
-
 from torch import _prims_common, _refs
-
 from torch._prims_common import (
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     wrappers as _prims_common_wrappers,
@@ -24,15 +20,16 @@ from torch._prims_common import (
 from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
 from torch._refs.nn import functional as _functional_refs
 from torch.fx.experimental import proxy_tensor
-
 from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
 from torch.utils import _python_dispatch, _pytree
 
+
 if TYPE_CHECKING:
     from types import ModuleType
 
     from torch._subclasses import fake_tensor
 
+
 logger = logging.getLogger(__name__)
 
 # TODO(bowbao): move to type utils.
diff --git a/torch/onnx/_internal/fx/passes/virtualization.py b/torch/onnx/_internal/fx/passes/virtualization.py
index b94fadec8305..456c25fee777 100644
--- a/torch/onnx/_internal/fx/passes/virtualization.py
+++ b/torch/onnx/_internal/fx/passes/virtualization.py
@@ -4,9 +4,9 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
 import torch
-
 from torch.onnx._internal.fx import _pass
 
+
 if TYPE_CHECKING:
     import torch.fx
diff --git a/torch/onnx/_internal/fx/patcher.py b/torch/onnx/_internal/fx/patcher.py
index 239edb6dde63..04298159deb0 100644
--- a/torch/onnx/_internal/fx/patcher.py
+++ b/torch/onnx/_internal/fx/patcher.py
@@ -5,6 +5,7 @@ from typing import List, TYPE_CHECKING, Union
 
 import torch
 
+
 if TYPE_CHECKING:
     import io
 
@@ -16,7 +17,6 @@ def has_safetensors_and_transformers():
         # safetensors is not an exporter requirement, but needed for some huggingface models
         import safetensors  # type: ignore[import]  # noqa: F401
         import transformers  # type: ignore[import]  # noqa: F401
-
         from safetensors import torch as safetensors_torch  # noqa: F401
 
         return True
diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py
index 775b49ef2a1a..a5f3f3d7b07c 100644
--- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py
+++ b/torch/onnx/_internal/fx/torch_export_graph_extractor.py
@@ -13,6 +13,7 @@ import torch.fx
 from torch.onnx._internal import exporter, io_adapter
 from torch.onnx._internal.diagnostics import infra
 
+
 if TYPE_CHECKING:
     import torch.onnx
     from torch.export.exported_program import ExportedProgram
diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py
index d32fbd44cd6b..8aba18d5918a 100644
--- a/torch/onnx/_internal/fx/type_utils.py
+++ b/torch/onnx/_internal/fx/type_utils.py
@@ -15,11 +15,13 @@ from typing import (
 )
 
 import numpy
+
 import onnx
 
 import torch
 from torch._subclasses import fake_tensor
 
+
 if TYPE_CHECKING:
     import onnx.defs.OpSchema.AttrType  # type: ignore[import]  # noqa: TCH004
 
diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py
index e2f322809b7d..fb89cad85ac5 100644
--- a/torch/onnx/_internal/io_adapter.py
+++ b/torch/onnx/_internal/io_adapter.py
@@ -13,9 +13,9 @@ from typing import (
 
 import torch
 import torch.export as torch_export
-
 from torch.utils import _pytree as pytree
 
+
 if TYPE_CHECKING:
     import inspect
 
diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py
index de72922ef44a..ec5dc4b96200 100644
--- a/torch/onnx/_internal/jit_utils.py
+++ b/torch/onnx/_internal/jit_utils.py
@@ -1,10 +1,11 @@
 # mypy: allow-untyped-defs
 """Utilities for manipulating the torch.Graph object and the torchscript."""
-from __future__ import annotations
 
 # TODO(justinchuby): Move more of the symbolic helper functions here and expose
 # them to the user.
+from __future__ import annotations
+
 import dataclasses
 import re
 import typing
diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py
index 59609866bae8..3cc5aa3167e5 100644
--- a/torch/onnx/_internal/onnxruntime.py
+++ b/torch/onnx/_internal/onnxruntime.py
@@ -4,7 +4,6 @@ import dataclasses
 import importlib
 import logging
 import os
-
 from typing import (
     Any,
     Callable,
@@ -33,6 +32,7 @@ from torch.fx.passes.operator_support import OperatorSupport
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
 from torch.utils import _pytree
 
+
 if TYPE_CHECKING:
     import onnx
 
@@ -929,7 +929,7 @@ class OrtBackend:
         try:
             from onnxscript import optimizer  # type: ignore[import]
             from onnxscript.rewriter import (  # type: ignore[import]
-                onnxruntime as ort_rewriter,  # type: ignore[import]
+                onnxruntime as ort_rewriter,
             )
 
             onnx_model = optimizer.optimize(onnx_model)
@@ -1112,7 +1112,6 @@ class OrtBackend:
         the ``compile`` method is invoked directly."""
         if self._options.use_aot_autograd:
             from functorch.compile import min_cut_rematerialization_partition
-
             from torch._dynamo.backends.common import aot_autograd
 
             return aot_autograd(
diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_type_utils.py
index 1099130c7f17..1d80dc89ea71 100644
--- a/torch/onnx/_type_utils.py
+++ b/torch/onnx/_type_utils.py
@@ -10,6 +10,7 @@ import torch
 from torch._C import _onnx as _C_onnx
 from torch.onnx import errors
 
+
 if typing.TYPE_CHECKING:
     # Hack to help mypy to recognize torch._C.Value
     from torch import _C  # noqa: F401
diff --git a/torch/onnx/errors.py b/torch/onnx/errors.py
index 483ca22097fd..456d05e7593e 100644
--- a/torch/onnx/errors.py
+++ b/torch/onnx/errors.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
 from torch.onnx import _constants
 from torch.onnx._internal import diagnostics
 
+
 if TYPE_CHECKING:
     from torch import _C
 
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index eb853d39b31f..6c8556248742 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -19,6 +19,7 @@ from torch.onnx import _constants, _type_utils, errors, utils
 from torch.onnx._globals import GLOBALS
 from torch.onnx._internal import jit_utils
 
+
 if typing.TYPE_CHECKING:
     from torch.types import Number
 
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
index c30452185c55..975b6bdbe7d8 100644
--- a/torch/onnx/symbolic_opset10.py
+++ b/torch/onnx/symbolic_opset10.py
@@ -23,6 +23,7 @@ from torch.onnx import (
 from torch.onnx._globals import GLOBALS
 from torch.onnx._internal import jit_utils, registration
 
+
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in README.md
 
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index eafd440639e4..22f3d4d6de1d 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -21,6 +21,7 @@ from torch.onnx import (
 )
 from torch.onnx._internal import jit_utils, registration
 
+
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in README.md
 
diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py
index b7bf47192a2e..2e2acec72f97 100644
--- a/torch/onnx/symbolic_opset14.py
+++ b/torch/onnx/symbolic_opset14.py
@@ -25,6 +25,7 @@ from torch.onnx import _constants, _type_utils, symbolic_helper
 from torch.onnx._globals import GLOBALS
 from torch.onnx._internal import jit_utils, registration
 
+
 __all__ = [
     "hardswish",
     "tril",
diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py
index fd876d3c00f6..08f8dcbf5a22 100644
--- a/torch/onnx/symbolic_opset15.py
+++ b/torch/onnx/symbolic_opset15.py
@@ -33,6 +33,7 @@ from torch import _C
 from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
 from torch.onnx._internal import jit_utils, registration
 
+
 _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
 
diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py
index 696f1fecb022..3c3af5f55ac9 100644
--- a/torch/onnx/symbolic_opset16.py
+++ b/torch/onnx/symbolic_opset16.py
@@ -36,6 +36,7 @@ from torch.nn.functional import (
 )
 from torch.onnx import _type_utils, errors, symbolic_helper, utils
 from torch.onnx._internal import jit_utils, registration
 
+
 _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
 
diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py
index 884dcf6f4bbb..0aca6634d2f6 100644
--- a/torch/onnx/symbolic_opset17.py
+++ b/torch/onnx/symbolic_opset17.py
@@ -25,6 +25,7 @@ from torch import _C
 from torch.onnx import _type_utils, errors, symbolic_helper
 from torch.onnx._internal import jit_utils, registration
 
+
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in README.md
 
diff --git a/torch/onnx/symbolic_opset18.py b/torch/onnx/symbolic_opset18.py
index ee441b59ff0c..d28fadc1bab1 100644
--- a/torch/onnx/symbolic_opset18.py
+++ b/torch/onnx/symbolic_opset18.py
@@ -27,6 +27,7 @@ from torch import _C
 from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
 from torch.onnx._internal import jit_utils, registration
 
+
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in symbolic_helper.py
 
diff --git a/torch/onnx/symbolic_opset19.py b/torch/onnx/symbolic_opset19.py
index 81b69dd3cc20..a97dda26f81f 100644
--- a/torch/onnx/symbolic_opset19.py
+++ b/torch/onnx/symbolic_opset19.py
@@ -26,6 +26,7 @@ Size
 
 from typing import List
 
+
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in symbolic_helper.py
 
diff --git a/torch/onnx/symbolic_opset20.py b/torch/onnx/symbolic_opset20.py
index fe4fa1e62315..ed1997bc803c 100644
--- a/torch/onnx/symbolic_opset20.py
+++ b/torch/onnx/symbolic_opset20.py
@@ -24,11 +24,11 @@ New operators:
 import functools
 
 import torch.nn.functional as F
-
 from torch import _C
 from torch.onnx import symbolic_helper
 from torch.onnx._internal import jit_utils, registration
 
+
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in symbolic_helper.py
 
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 191a3259efb1..41abf46be2a0 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -39,6 +39,7 @@ from torch._C import _onnx as _C_onnx
 from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
 from torch.onnx._internal import jit_utils, registration
 
+
 _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
 
 block_listed_operators = (
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 5b3acb28348f..a5f41c231272 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -27,6 +27,7 @@ from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_h
 from torch.onnx._globals import GLOBALS
 from torch.onnx._internal import jit_utils, registration
 
+
 if TYPE_CHECKING:
     from torch.types import Number
 
diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py
index bcf1de6b6437..195e7a822851 100644
--- a/torch/onnx/verification.py
+++ b/torch/onnx/verification.py
@@ -30,6 +30,7 @@ from torch.onnx._globals import GLOBALS
 from torch.onnx._internal import onnx_proto_utils
 from torch.types import Number
 
+
 _ORT_PROVIDERS = ("CPUExecutionProvider",)
 
 _NumericType = Union[Number, torch.Tensor, np.ndarray]
diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py
index 285dd73276a1..ee53d1c5fd76 100644
--- a/torch/optim/__init__.py
+++ b/torch/optim/__init__.py
@@ -23,6 +23,7 @@ from torch.optim.rprop import Rprop
 from torch.optim.sgd import SGD
 from torch.optim.sparse_adam import SparseAdam
 
+
 Adafactor.__module__ = "torch.optim"
 
diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py
index b1f6ada45805..73ba8658750d 100644
--- a/torch/optim/_adafactor.py
+++ b/torch/optim/_adafactor.py
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
+
 from .optimizer import (
     _disable_dynamo_if_unsupported,
     _get_scalar_dtype,
@@ -12,6 +13,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["Adafactor", "adafactor"]
 
diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py
index fb719b3a0d8b..d1a05d6df70b 100644
--- a/torch/optim/adadelta.py
+++ b/torch/optim/adadelta.py
@@ -20,6 +20,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["Adadelta", "adadelta"]
 
diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py
index 016389a4a6c0..ba8a1c895a38 100644
--- a/torch/optim/adagrad.py
+++ b/torch/optim/adagrad.py
@@ -4,6 +4,7 @@ from typing import List, Optional, Union
 
 import torch
 from torch import Tensor
 from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
+
 from .optimizer import (
     _default_to_fused_or_foreach,
     _differentiable_doc,
@@ -17,6 +18,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["Adagrad", "adagrad"]
 
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 661083996668..97648b86bec1 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
+
 from .optimizer import (
     _capturable_doc,
     _default_to_fused_or_foreach,
@@ -24,6 +25,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["Adam", "adam"]
 
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index b74c9e011680..7cb5e464f5a6 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -21,6 +21,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["Adamax", "adamax"]
 
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index 0687d13c3dfe..345b4369050c 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -5,6 +5,7 @@ from typing import cast, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
+
 from .optimizer import (
     _capturable_doc,
     _default_to_fused_or_foreach,
@@ -24,6 +25,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["AdamW", "adamw"]
 
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index 163dbc0c9720..1d8402edc48e 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -21,6 +21,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["ASGD", "asgd"]
 
diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py
index 7178292ff58e..f9c2e13077e3 100644
--- a/torch/optim/lbfgs.py
+++ b/torch/optim/lbfgs.py
@@ -3,8 +3,10 @@ from typing import Optional, Union
 
 import torch
 from torch import Tensor
+
 from .optimizer import Optimizer, ParamsT
 
+
 __all__ = ["LBFGS"]
 
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index f9494d0f36b6..a77689a44d6c 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -26,6 +26,7 @@ from torch import inf, Tensor
 
 from .optimizer import Optimizer
 
+
 __all__ = [
     "LambdaLR",
     "MultiplicativeLR",
diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py
index 6ec27df939c3..54cc8df5a9b7 100644
--- a/torch/optim/nadam.py
+++ b/torch/optim/nadam.py
@@ -5,6 +5,7 @@ from typing import cast, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
+
 from .optimizer import (
     _capturable_doc,
     _default_to_fused_or_foreach,
@@ -22,6 +23,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["NAdam", "nadam"]
 
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 6cbc18b06282..956c4167f42e 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -35,6 +35,7 @@ from torch.utils._foreach_utils import (
 )
 from torch.utils.hooks import RemovableHandle
 
+
 Args: TypeAlias = Tuple[Any, ...]
 Kwargs: TypeAlias = Dict[str, Any]
 StateDict: TypeAlias = Dict[str, Any]
diff --git a/torch/optim/radam.py b/torch/optim/radam.py
index 6ec7ca221fdc..24949ea4e05d 100644
--- a/torch/optim/radam.py
+++ b/torch/optim/radam.py
@@ -22,6 +22,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["RAdam", "radam"]
 
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index 860d1f6aa7d9..c9b33684f48e 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Union
 
 import torch
 from torch import Tensor
+
 from .optimizer import (
     _capturable_doc,
     _default_to_fused_or_foreach,
@@ -20,6 +21,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["RMSprop", "rmsprop"]
 
diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py
index 395479bcac89..ba0be649a8fc 100644
--- a/torch/optim/rprop.py
+++ b/torch/optim/rprop.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
+
 from .optimizer import (
     _capturable_doc,
     _default_to_fused_or_foreach,
@@ -20,6 +21,7 @@ from .optimizer import (
     ParamsT,
 )
 
+
 __all__ = ["Rprop", "rprop"]
 
diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py
index bd1283dc687a..c9b2b169b1a7 100644
--- a/torch/optim/sgd.py
+++ b/torch/optim/sgd.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Union
 
 import torch
 from torch import Tensor
 from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
+
 from .optimizer import (
     _default_to_fused_or_foreach,
     _differentiable_doc,
@@ -16,6 +17,7 @@ from .optimizer import (
     Optimizer,
 )
 
+
 __all__ = ["SGD", "sgd"]
 
diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py
index 5f2069a781cd..22ef7841270f 100644
--- a/torch/optim/sparse_adam.py
+++ b/torch/optim/sparse_adam.py
@@ -3,9 +3,11 @@ from typing import List, Tuple, Union
 
 import torch
 from torch import Tensor
+
 from . import _functional as F
 from .optimizer import _maximize_doc, Optimizer, ParamsT
 
+
 __all__ = ["SparseAdam"]
 
diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py
index a17f387286da..73c71eddecac 100644
--- a/torch/optim/swa_utils.py
+++ b/torch/optim/swa_utils.py
@@ -11,8 +11,10 @@ from torch import Tensor
 from torch.nn import Module
 from torch.optim.lr_scheduler import _format_param, LRScheduler
 from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
+
 from .optimizer import Optimizer
 
+
 __all__ = [
     "AveragedModel",
     "update_bn",
@@ -25,6 +27,7 @@ __all__ = [
 
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
+
 PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
 
diff --git a/torch/package/_directory_reader.py b/torch/package/_directory_reader.py
index 77d629cccce2..f58065f47dc4 100644
--- a/torch/package/_directory_reader.py
+++ b/torch/package/_directory_reader.py
@@ -6,6 +6,7 @@ from typing import cast
 
 import torch
 from torch.types import Storage
 
+
 __serialization_id_record_name__ = ".data/serialization_id"
 
diff --git a/torch/package/_importlib.py b/torch/package/_importlib.py
index 9741925315e5..609efd294c4c 100644
--- a/torch/package/_importlib.py
+++ b/torch/package/_importlib.py
@@ -2,6 +2,7 @@ import _warnings
 import os.path
 
+
 # note: implementations
 # copied from cpython's import code
 
diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py
index 700a9ad6a04a..0cf3791d1604 100644
--- a/torch/package/_mangling.py
+++ b/torch/package/_mangling.py
@@ -4,6 +4,7 @@ See mangling.md for details.
 """
 import re
 
+
 _mangle_index = 0
 
diff --git a/torch/package/analyze/find_first_use_of_broken_modules.py b/torch/package/analyze/find_first_use_of_broken_modules.py
index 1910afdd98e3..b3016a56c2a4 100644
--- a/torch/package/analyze/find_first_use_of_broken_modules.py
+++ b/torch/package/analyze/find_first_use_of_broken_modules.py
@@ -1,6 +1,7 @@
 from typing import Dict, List
 
-from ..package_exporter import PackagingError
+from torch.package.package_exporter import PackagingError
+
 
 __all__ = ["find_first_use_of_broken_modules"]
 
diff --git a/torch/package/analyze/trace_dependencies.py b/torch/package/analyze/trace_dependencies.py
index 405fcf2f9bc2..23f6c998385b 100644
--- a/torch/package/analyze/trace_dependencies.py
+++ b/torch/package/analyze/trace_dependencies.py
@@ -2,6 +2,7 @@ import sys
 from typing import Any, Callable, Iterable, List, Tuple
 
+
 __all__ = ["trace_dependencies"]
 
diff --git a/torch/package/file_structure_representation.py b/torch/package/file_structure_representation.py
index 44e07978640f..e1137234ab73 100644
--- a/torch/package/file_structure_representation.py
+++ b/torch/package/file_structure_representation.py
@@ -3,6 +3,7 @@ from typing import Dict, List
 
 from .glob_group import GlobGroup, GlobPattern
 
+
 __all__ = ["Directory"]
 
diff --git a/torch/package/glob_group.py b/torch/package/glob_group.py
index 974364400502..1c1d31930fd1 100644
--- a/torch/package/glob_group.py
+++ b/torch/package/glob_group.py
@@ -2,6 +2,7 @@ import re
 from typing import Iterable, Union
 
+
 GlobPattern = Union[str, Iterable[str]]
 
diff --git a/torch/package/importer.py b/torch/package/importer.py
index 513847513910..4983d2f6995d 100644
--- a/torch/package/importer.py
+++ b/torch/package/importer.py
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import importlib
 from abc import ABC, abstractmethod
-from pickle import (  # type: ignore[attr-defined]  # type: ignore[attr-defined]
+from pickle import (  # type: ignore[attr-defined]
     _getattribute,
     _Pickler,
     whichmodule as _pickle_whichmodule,
@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 from ._mangling import demangle, get_mangle_prefix, is_mangled
 
+
 __all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]
 
diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py
index bfa00278fa4b..7f1f8cc2f29d 100644
--- a/torch/package/package_exporter.py
+++ b/torch/package/package_exporter.py
@@ -39,6 +39,7 @@ from .find_file_dependencies import find_files_source_depends_on
 from .glob_group import GlobGroup, GlobPattern
 from .importer import Importer, OrderedImporter, sys_importer
 
+
 __all__ = [
     "PackagingErrorReason",
     "EmptyMatchError",
diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py
index 1a103ab6c5c9..3eb06277f1da 100644
--- a/torch/package/package_importer.py
+++ b/torch/package/package_importer.py
@@ -38,6 +38,7 @@ from ._package_unpickler import PackageUnpickler
 from .file_structure_representation import _create_directory_from_file_list, Directory
 from .importer import Importer
 
+
 if TYPE_CHECKING:
     from .glob_group import GlobPattern
 
diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py
index 4a681daf788e..073096607afe 100644
--- a/torch/profiler/__init__.py
+++ b/torch/profiler/__init__.py
@@ -25,6 +25,7 @@ from .profiler import (
     tensorboard_trace_handler,
 )
 
+
 __all__ = [
     "profile",
     "schedule",
diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py
index 1834f0494e02..2095b882f5de 100644
--- a/torch/profiler/_memory_profiler.py
+++ b/torch/profiler/_memory_profiler.py
@@ -32,6 +32,7 @@ from torch._C._profiler import (
 from torch._utils import _element_size
 from torch.profiler import _utils
 
+
 KeyAndID = Tuple["Key", int]
 TensorAndID = Tuple["TensorKey", int]
 
diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py
index d69fa4630595..20dfeb80adeb 100644
--- a/torch/profiler/_utils.py
+++ b/torch/profiler/_utils.py
@@ -7,9 +7,9 @@ from dataclasses import dataclass
 from typing import Dict, List, TYPE_CHECKING
 
 from torch.autograd.profiler import profile
-
 from torch.profiler import DeviceType
 
+
 if TYPE_CHECKING:
     from torch.autograd import _KinetoEvent
 
diff --git a/torch/profiler/itt.py b/torch/profiler/itt.py
index 4666bba515a3..9d4bda2b3420 100644
--- a/torch/profiler/itt.py
+++ b/torch/profiler/itt.py
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 from contextlib import contextmanager
 
+
 try:
     from torch._C import _itt
 except ImportError:
diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index a82518db6084..8789fea17a17 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -1,16 +1,14 @@
 # mypy: allow-untyped-defs
-from .quantize import *  # noqa: F403
-from .observer import *  # noqa: F403
-from .qconfig import *  # noqa: F403
 from .fake_quantize import *  # noqa: F403
 from .fuse_modules import fuse_modules
-from .stubs import *  # noqa: F403
-from .quant_type import *  # noqa: F403
-from .quantize_jit import *  # noqa: F403
-
-# from .quantize_fx import *
-from .quantization_mappings import *  # noqa: F403
 from .fuser_method_mappings import *  # noqa: F403
+from .observer import *  # noqa: F403
+from .qconfig import *  # noqa: F403
+from .quant_type import *  # noqa: F403
+from .quantization_mappings import *  # noqa: F403
+from .quantize import *  # noqa: F403
+from .quantize_jit import *  # noqa: F403
+from .stubs import *  # noqa: F403
 
 
 def default_eval_fn(model, calib_data):
diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py
index 26954833bb48..2a83e180fc4d 100644
--- a/torch/quantization/fx/pattern_utils.py
+++ b/torch/quantization/fx/pattern_utils.py
@@ -15,6 +15,7 @@ from torch.ao.quantization.fx.pattern_utils import (
     QuantizeHandler,
 )
 
+
 # QuantizeHandler.__module__ = _NAMESPACE
 _register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils"
 get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils"
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 34ee88a4713c..20d8cc52ee4f 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -23,6 +23,7 @@ from torch.ao.quantization.fx.quantize_handler import (
     StandaloneModuleQuantizeHandler,
 )
 
+
 QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
 BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
 CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index 8b3a1f2e2b2d..b99061cd0cbd 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -15,6 +15,7 @@ from .semi_structured import (
     to_sparse_semi_structured,
 )
 
+
 if TYPE_CHECKING:
     from torch.types import _dtype as DType
 
diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py
index bcaa889ba1ee..d7c76d8d8be8 100644
--- a/torch/sparse/_semi_structured_ops.py
+++ b/torch/sparse/_semi_structured_ops.py
@@ -3,6 +3,7 @@ import contextlib
 
 import torch
 
+
 __all__ = [
     "fallback_dispatcher",
     "semi_sparse_values",
diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py
index 4ab313064ba4..ccd29b9f5a5b 100644
--- a/torch/sparse/_triton_ops.py
+++ b/torch/sparse/_triton_ops.py
@@ -8,8 +8,10 @@ from typing import Optional, Tuple
 
 import torch
 from torch.utils._triton import has_triton
+
 from ._triton_ops_meta import get_meta
 
+
 TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(
     os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2)
 )
diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py
index 23193021232a..a4df1306ac19 100644
--- a/torch/sparse/semi_structured.py
+++ b/torch/sparse/semi_structured.py
@@ -20,6 +20,7 @@ from torch.sparse._semi_structured_ops import (
     semi_sparse_view,
 )
 
+
 __all__ = [
     "SparseSemiStructuredTensor",
     "SparseSemiStructuredTensorCUTLASS",
diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py
index 352ce67e074a..de042277c7c8 100644
--- a/torch/testing/__init__.py
+++ b/torch/testing/__init__.py
@@ -1,4 +1,5 @@
 from torch._C import FileCheck as FileCheck
+
 from . import _utils
 from ._comparison import assert_allclose, assert_close as assert_close
 from ._creation import make_tensor as make_tensor
diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py
index d8fb2ef18b1d..9de3bd09882e 100644
--- a/torch/testing/_creation.py
+++ b/torch/testing/_creation.py
@@ -9,6 +9,7 @@ from typing import cast, List, Optional, Tuple, Union
 
 import torch
 
+
 _INTEGRAL_TYPES = [
     torch.uint8,
     torch.int8,
diff --git a/torch/testing/_internal/distributed/common_state_dict.py b/torch/testing/_internal/distributed/common_state_dict.py
index 36b8cbf943ed..5aa8d7f14aa0 100644
--- a/torch/testing/_internal/distributed/common_state_dict.py
+++ b/torch/testing/_internal/distributed/common_state_dict.py
@@ -8,7 +8,6 @@ from typing import Any, Dict
 
 import torch
 import torch.nn as nn
-
 from torch.distributed._sharded_tensor import ShardedTensor
 from torch.distributed._state_dict_utils import _gather_state_dict
 from torch.distributed._tensor import DTensor
diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py
index 1d542d694a55..6deb1b2af13c 100644
--- a/torch/testing/_internal/dynamo_test_failures.py
+++ b/torch/testing/_internal/dynamo_test_failures.py
@@ -3,6 +3,7 @@ import logging
 import os
 import sys
 
+
 # NOTE: [dynamo_test_failures.py]
 #
 # We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures`
diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py
index ecccc4fbd8d5..2aa38511d4e9 100644
--- a/torch/testing/_internal/opinfo/core.py
+++ b/torch/testing/_internal/opinfo/core.py
@@ -34,9 +34,9 @@ from torch.testing._internal.common_utils import (
     TrackedInputIter,
 )
 from torch.testing._internal.opinfo import utils
-
 from torchgen.utils import dataclass_repr
 
+
 # Reasonable testing sizes for dimensions
 L = 20
 M = 10
diff --git a/torch/testing/_internal/opinfo/definitions/__init__.py b/torch/testing/_internal/opinfo/definitions/__init__.py
index bd2ae805370c..4820a3eae232 100644
--- a/torch/testing/_internal/opinfo/definitions/__init__.py
+++ b/torch/testing/_internal/opinfo/definitions/__init__.py
@@ -11,6 +11,7 @@ from torch.testing._internal.opinfo.definitions import (
     special,
 )
 
+
 # Operator database
 op_db: List[OpInfo] = [
     *fft.op_db,
diff --git a/torch/testing/_internal/opinfo/definitions/fft.py b/torch/testing/_internal/opinfo/definitions/fft.py
index 65c9d6c08dac..6ed395eef020 100644
--- a/torch/testing/_internal/opinfo/definitions/fft.py
+++ b/torch/testing/_internal/opinfo/definitions/fft.py
@@ -7,7 +7,6 @@ from typing import List
 import numpy as np
 
 import torch
-
 from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import SM53OrLater
 from torch.testing._internal.common_device_type import precisionOverride
@@ -31,6 +30,7 @@ from torch.testing._internal.opinfo.refs import (
     PythonRefInfo,
 )
 
+
 has_scipy_fft = False
 if TEST_SCIPY:
     try:
diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py
index 7956f5b95b66..e94c6a671144 100644
--- a/torch/testing/_internal/opinfo/definitions/linalg.py
+++ b/torch/testing/_internal/opinfo/definitions/linalg.py
@@ -11,7 +11,6 @@ import numpy as np
 from numpy import inf
 
 import torch
-
 from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import (
     _get_magma_version,
diff --git a/torch/testing/_internal/opinfo/definitions/signal.py b/torch/testing/_internal/opinfo/definitions/signal.py
index 6f51e388966a..105590a71fb7 100644
--- a/torch/testing/_internal/opinfo/definitions/signal.py
+++ b/torch/testing/_internal/opinfo/definitions/signal.py
@@ -2,7 +2,6 @@
 
 import unittest
 from functools import partial
-
 from itertools import product
 from typing import Callable, List, Tuple
 
@@ -18,6 +17,7 @@ from torch.testing._internal.opinfo.core import (
     SampleInput,
 )
 
+
 if TEST_SCIPY:
     import scipy.signal
 
diff --git a/torch/testing/_internal/opinfo/refs.py b/torch/testing/_internal/opinfo/refs.py
index 92bbdf8d6b2e..435a9d113164 100644
--- a/torch/testing/_internal/opinfo/refs.py
+++ b/torch/testing/_internal/opinfo/refs.py
@@ -7,6 +7,7 @@ from torch.testing._internal.opinfo.core import (
     UnaryUfuncInfo,
 )
 
+
 # NOTE [Python References]
 # Python References emulate existing PyTorch operations, but can ultimately
 # be expressed in terms of "primitive" operations from torch._prims.
diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py
index 70ee48274800..7fac1e57c6ac 100644
--- a/torch/testing/_internal/optests/generate_tests.py
+++ b/torch/testing/_internal/optests/generate_tests.py
@@ -13,9 +13,7 @@ import unittest
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
-
 import torch._dynamo
-
 import torch.utils._pytree as pytree
 from torch._dynamo.utils import clone_input
 from torch._library.custom_ops import CustomOpDef
diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py
index ab54e8b8bec9..f9a50d68c512 100644
--- a/torch/testing/_internal/triton_utils.py
+++ b/torch/testing/_internal/triton_utils.py
@@ -5,6 +5,7 @@ import unittest
 from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
 from torch.utils._triton import has_triton
 
+
 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
 requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")
 
diff --git a/torch/testing/_utils.py b/torch/testing/_utils.py
index 50d077cb1649..5b4b7c3796e3 100644
--- a/torch/testing/_utils.py
+++ b/torch/testing/_utils.py
@@ -3,6 +3,7 @@ import contextlib
 
 import torch
 
+
 # Common testing utilities for use in public testing APIs.
 # NB: these should all be importable without optional dependencies
 # (like numpy and expecttest).
diff --git a/torch/utils/_backport_slots.py b/torch/utils/_backport_slots.py
index 3e265ce20fad..dcafb32877f3 100644
--- a/torch/utils/_backport_slots.py
+++ b/torch/utils/_backport_slots.py
@@ -7,10 +7,12 @@ import dataclasses
 import itertools
 from typing import Generator, List, Type, TYPE_CHECKING, TypeVar
 
+
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
 
-__all__ = ("dataclass_slots",)
+
+__all__ = ["dataclass_slots"]
 
 _T = TypeVar("_T", bound="DataclassInstance")
 
diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py
index abff5af82b77..3efc39b2635d 100644
--- a/torch/utils/_config_module.py
+++ b/torch/utils/_config_module.py
@@ -1,5 +1,4 @@
 import contextlib
-
 import copy
 import hashlib
 import inspect
@@ -13,6 +12,7 @@ from typing import Any, Callable, Dict, NoReturn, Optional, Set, Union
 from typing_extensions import deprecated
 from unittest import mock
 
+
 # Types saved/loaded in configs
 CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
 
diff --git a/torch/utils/_content_store.py b/torch/utils/_content_store.py
index dec70d90b7d3..deb3ba9008c0 100644
--- a/torch/utils/_content_store.py
+++ b/torch/utils/_content_store.py
@@ -41,7 +41,6 @@ import torch._prims as prims
 import torch._utils
 import torch.nn.functional as F
 from torch._C import default_generator
-
 from torch.multiprocessing.reductions import StorageWeakRef
 
diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py
index 753e9ca8542c..7b7cdc8a7e25 100644
--- a/torch/utils/_cxx_pytree.py
+++ b/torch/utils/_cxx_pytree.py
@@ -33,9 +33,9 @@ import optree
 from optree import PyTreeSpec  # direct import for type annotations
 
 import torch.utils._pytree as _pytree
-
 from torch.utils._pytree import KeyEntry
 
+
 __all__ = [
     "PyTree",
     "Context",
diff --git a/torch/utils/_ordered_set.py b/torch/utils/_ordered_set.py
index 22a4159bbb13..cb3d6a686ea3 100644
--- a/torch/utils/_ordered_set.py
+++ b/torch/utils/_ordered_set.py
@@ -15,6 +15,7 @@ from typing import (
     TypeVar,
 )
 
+
 T = TypeVar("T")
 T_co = TypeVar("T_co", covariant=True)
 
diff --git a/torch/utils/_strobelight/examples/cli_function_profiler_example.py b/torch/utils/_strobelight/examples/cli_function_profiler_example.py
index 222a70c9fe2d..d92fa3b8a603 100644
--- a/torch/utils/_strobelight/examples/cli_function_profiler_example.py
+++ b/torch/utils/_strobelight/examples/cli_function_profiler_example.py
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import torch
-
 from torch.utils._strobelight.cli_function_profiler import (
     strobelight,
     StrobelightCLIFunctionProfiler,
diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py
index 3d04b35b0b80..58683f8f6e67 100644
--- a/torch/utils/_sympy/functions.py
+++ b/torch/utils/_sympy/functions.py
@@ -10,6 +10,7 @@ from sympy.core.numbers import equal_valued
 
 from .numbers import int_oo
 
+
 __all__ = [
     "FloorDiv",
     "ModularIndexing",
diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py
index 65170a1881ed..04b99748c444 100644
--- a/torch/utils/_sympy/interp.py
+++ b/torch/utils/_sympy/interp.py
@@ -16,6 +16,7 @@ import sympy
 from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
 
 import torch
+
 from .functions import (
     CeilToInt,
     CleanDiv,
diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py
index 642044b43bfa..21ec1feb0ea6 100644
--- a/torch/utils/_sympy/reference.py
+++ b/torch/utils/_sympy/reference.py
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import math
-
 import operator
 
 import sympy
diff --git a/torch/utils/_sympy/singleton_int.py b/torch/utils/_sympy/singleton_int.py
index 1b5e8a96104f..0bac76121f8b 100644
--- a/torch/utils/_sympy/singleton_int.py
+++ b/torch/utils/_sympy/singleton_int.py
@@ -2,6 +2,7 @@ import sympy
 from sympy.multipledispatch import dispatch
 
+
 __all__ = ["SingletonInt"]
 
diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py
index 6a5e6efe134b..e122d6cd0b5f 100644
--- a/torch/utils/_sympy/solve.py
+++ b/torch/utils/_sympy/solve.py
@@ -1,11 +1,11 @@
 import logging
-
 from typing import Dict, Optional, Tuple, Type
 
 import sympy
 
 from torch.utils._sympy.functions import FloorDiv
 
+
 log = logging.getLogger(__name__)
 
 _MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py
index 29ee1886261b..57d509323e01 100644
--- a/torch/utils/_sympy/value_ranges.py
+++ b/torch/utils/_sympy/value_ranges.py
@@ -24,8 +24,8 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
 
 import torch
 from torch._logging import LazyString
-
 from torch._prims_common import dtype_to_type
+
 from .functions import (
     _keep_float,
     FloatTrueDiv,
@@ -45,6 +45,7 @@ from .functions import (
 from .interp import sympy_interp
 from .numbers import int_oo, IntInfinity, NegativeIntInfinity
 
+
 log = logging.getLogger(__name__)
 
 __all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"]
diff --git a/torch/utils/_typing_utils.py b/torch/utils/_typing_utils.py
index fd1b6ca5785f..ffb6b383e4e6 100644
--- a/torch/utils/_typing_utils.py
+++ b/torch/utils/_typing_utils.py
@@ -2,6 +2,7 @@
 
 from typing import Optional, TypeVar
 
+
 # Helper to turn Optional[T] into T when we know None either isn't
 # possible or should trigger an exception.
 T = TypeVar("T")
diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py
index e3e99df77780..f15b70ee66d6 100644
--- a/torch/utils/data/datapipes/_typing.py
+++ b/torch/utils/data/datapipes/_typing.py
@@ -11,9 +11,11 @@ import sys
 # In case of metaclass conflict due to ABCMeta or _ProtocolMeta
 # For Python 3.9, only Protocol in typing uses metaclass
 from abc import ABCMeta
+
+# TODO: Use TypeAlias when Python 3.6 is deprecated
 from typing import (  # type: ignore[attr-defined]
     _eval_type,
-    _GenericAlias,  # TODO: Use TypeAlias when Python 3.6 is deprecated
+    _GenericAlias,
     _tp_cache,
     _type_check,
     _type_repr,
diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py
index 01e966c712b5..91958127b03b 100644
--- a/torch/utils/module_tracker.py
+++ b/torch/utils/module_tracker.py
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import weakref
-
 from typing import Set
 
 import torch
@@ -11,6 +10,7 @@ from torch.nn.modules.module import (
 )
 from torch.utils._pytree import tree_flatten
 
+
 __all__ = ["ModuleTracker"]
 
diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py
index 6049a11861d2..b6eafc38ff1d 100644
--- a/torch/xpu/__init__.py
+++ b/torch/xpu/__init__.py
@@ -13,11 +13,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch._C
-from .. import device as _device
-from .._utils import _dummy_type, _LazySeedTracker
+from torch import device as _device
+from torch._utils import _dummy_type, _LazySeedTracker
+
 from ._utils import _get_device_index
 from .streams import Event, Stream
 
+
 _initialized = False
 _tls = threading.local()
 _initialization_lock = threading.Lock()
diff --git a/torch/xpu/random.py b/torch/xpu/random.py
index 1ebdd476ed8c..b8631ddc1850 100644
--- a/torch/xpu/random.py
+++ b/torch/xpu/random.py
@@ -2,7 +2,8 @@ from typing import Iterable, List, Union
 
 import torch
-from .. import Tensor
+from torch import Tensor
+
 from . import _lazy_call, _lazy_init, current_device, device_count