[BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771)

See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129771
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
This commit is contained in:
Xuehai Pan
2024-07-31 19:56:45 +08:00
committed by PyTorch MergeBot
parent c59f3fff52
commit 30293319a8
120 changed files with 163 additions and 101 deletions

View File

@ -56,7 +56,6 @@ ISORT_SKIPLIST = re.compile(
# torch/[e-n]*/**
"torch/[e-n]*/**",
# torch/[o-z]*/**
"torch/[o-z]*/**",
],
),
)

View File

@ -3,6 +3,28 @@ from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from ._exporter_states import ExportTypes
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
_deprecation,
errors,
@ -25,20 +47,6 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
utils,
)
from ._exporter_states import ExportTypes
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from ._internal.exporter import ( # usort: skip. needs to be last to avoid circular import
DiagnosticOptions,
@ -53,12 +61,6 @@ from ._internal.exporter import ( # usort:skip. needs to be last to avoid circu
enable_fake_mode,
)
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
__all__ = [
# Modules

View File

@ -6,6 +6,7 @@ import warnings
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
_T = TypeVar("_T")
_P = ParamSpec("_P")

View File

@ -9,6 +9,7 @@ from ._diagnostic import (
from ._rules import rules
from .infra import levels
__all__ = [
"TorchScriptOnnxExportDiagnostic",
"ExportDiagnosticEngine",

View File

@ -7,12 +7,12 @@ import gzip
from typing import TYPE_CHECKING
import torch
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace
if TYPE_CHECKING:
from collections.abc import Generator

View File

@ -13,6 +13,7 @@ from typing import Tuple
# flake8: noqa
from torch.onnx._internal.diagnostics import infra
"""
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`

View File

@ -14,6 +14,7 @@ from ._infra import (
)
from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic
__all__ = [
"Diagnostic",
"DiagnosticContext",

View File

@ -4,14 +4,10 @@
from __future__ import annotations
import contextlib
import dataclasses
import gzip
import logging
from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
from typing_extensions import Self
from torch.onnx._internal.diagnostics import infra

View File

@ -7,7 +7,6 @@ import traceback
from typing import Any, Callable, Union
from torch._logging import LazyString
from torch.onnx._internal.diagnostics.infra import sarif

View File

@ -97,4 +97,5 @@ from torch.onnx._internal.diagnostics.infra.sarif._version_control_details impor
from torch.onnx._internal.diagnostics.infra.sarif._web_request import WebRequest
from torch.onnx._internal.diagnostics.infra.sarif._web_response import WebResponse
# flake8: noqa

View File

@ -1,5 +1,6 @@
from typing import Final
SARIF_VERSION: Final = "2.1.0"
SARIF_SCHEMA_LINK: Final = "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
# flake8: noqa

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import functools
import inspect
import traceback
from typing import Any, Callable, Mapping, Sequence

View File

@ -4,12 +4,10 @@ from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (O
)
import abc
import contextlib
import dataclasses
import logging
import os
import tempfile
import warnings
from collections import defaultdict
@ -27,11 +25,9 @@ from typing import (
from typing_extensions import Self
import torch
import torch._ops
import torch.export as torch_export
import torch.utils._pytree as pytree
from torch.onnx._internal import io_adapter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import (
@ -41,6 +37,7 @@ from torch.onnx._internal.fx import (
serialization as fx_serialization,
)
# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
@ -48,6 +45,7 @@ if TYPE_CHECKING:
import io
import onnx
import onnxruntime # type: ignore[import]
import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
@ -55,7 +53,6 @@ if TYPE_CHECKING:
)
from torch._subclasses import fake_tensor
from torch.onnx._internal.fx import diagnostics
_DEFAULT_OPSET_VERSION: Final[int] = 18

View File

@ -2,23 +2,20 @@
from __future__ import annotations
import abc
import contextlib
import dataclasses
import difflib
import io
import logging
import sys
from typing import Any, Callable, TYPE_CHECKING
import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
if TYPE_CHECKING:
from torch._subclasses import fake_tensor

View File

@ -1,5 +1,6 @@
from .unsupported_nodes import UnsupportedFxNodesAnalysis
__all__ = [
"UnsupportedFxNodesAnalysis",
]

View File

@ -15,7 +15,6 @@ from __future__ import annotations
import abc
import contextlib
from typing import Callable, Sequence
from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-found]
@ -26,6 +25,7 @@ from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-
import torch
from torch._decomp import decompositions
_NEW_OP_NAMESPACE: str = "onnx_export"
"""The namespace for the custom operator."""

View File

@ -8,7 +8,6 @@ from typing import Callable
import torch
import torch._ops
import torch.fx
from torch.onnx._internal.fx import registration

View File

@ -2,9 +2,7 @@
from __future__ import annotations
import dataclasses
import functools
from typing import Any, TYPE_CHECKING
import onnxscript # type: ignore[import]
@ -17,6 +15,7 @@ from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter
from torch.onnx._internal.fx import registration, type_utils as fx_type_utils
if TYPE_CHECKING:
import logging

View File

@ -16,7 +16,6 @@ from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
import torch
import torch.fx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal.fx import (
_pass,
diagnostics,

View File

@ -2,16 +2,15 @@
from __future__ import annotations
import functools
from typing import Any, Callable, Mapping, Sequence
import torch
import torch.fx
import torch.onnx
import torch.onnx._internal.fx.passes as passes
from torch.onnx._internal import exporter, io_adapter
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go though Python dispatcher, so

View File

@ -11,13 +11,13 @@ from typing import Any, Callable, Sequence, TYPE_CHECKING
import torch
import torch._ops
import torch.fx
from torch.onnx._internal.fx import (
diagnostics,
registration,
type_utils as fx_type_utils,
)
if TYPE_CHECKING:
import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]

View File

@ -13,7 +13,6 @@ import torch
import torch.fx
from torch.fx.experimental import symbolic_shapes
from torch.onnx import _constants, _type_utils as jit_type_utils
from torch.onnx._internal.fx import (
diagnostics,
fx_onnx_interpreter,

View File

@ -5,6 +5,7 @@ from .readability import RestoreParameterAndBufferNames
from .type_promotion import InsertTypePromotion
from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder
__all__ = [
"Decompose",
"InsertTypePromotion",

View File

@ -6,9 +6,7 @@ These functions should NOT be directly invoked outside of `passes` package.
from __future__ import annotations
import collections
import re
from typing import Callable
import torch.fx

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import contextlib
from typing import Callable, Mapping, TYPE_CHECKING
import torch
@ -12,6 +11,7 @@ from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
if TYPE_CHECKING:
import torch.fx
from torch._subclasses import fake_tensor

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import contextlib
from typing import Callable, TYPE_CHECKING
import torch
@ -14,6 +13,7 @@ from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
from torch._subclasses import fake_tensor

View File

@ -2,19 +2,17 @@
from __future__ import annotations
import abc
import collections
import copy
import operator
from typing import Any, Dict, Final, Generator, Iterator, Sequence, Tuple
import torch
import torch.fx
from torch.onnx._internal.fx import _pass, diagnostics
from torch.utils import _pytree as pytree
_FX_TRACER_NN_MODULE_META_TYPE = Tuple[str, type]
"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer."""
_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict

View File

@ -4,7 +4,6 @@ from __future__ import annotations
from typing import Sequence
import torch
from torch.onnx._internal.fx import _pass, diagnostics

View File

@ -3,20 +3,16 @@
from __future__ import annotations
import abc
import dataclasses
import inspect
import logging
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
import torch
import torch._ops
import torch.fx
import torch.fx.traceback as fx_traceback
from torch import _prims_common, _refs
from torch._prims_common import (
ELEMENTWISE_TYPE_PROMOTION_KIND,
wrappers as _prims_common_wrappers,
@ -24,15 +20,16 @@ from torch._prims_common import (
from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
from torch._refs.nn import functional as _functional_refs
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
from torch.utils import _python_dispatch, _pytree
if TYPE_CHECKING:
from types import ModuleType
from torch._subclasses import fake_tensor
logger = logging.getLogger(__name__)
# TODO(bowbao): move to type utils.

View File

@ -4,9 +4,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch.onnx._internal.fx import _pass
if TYPE_CHECKING:
import torch.fx

View File

@ -5,6 +5,7 @@ from typing import List, TYPE_CHECKING, Union
import torch
if TYPE_CHECKING:
import io
@ -16,7 +17,6 @@ def has_safetensors_and_transformers():
# safetensors is not an exporter requirement, but needed for some huggingface models
import safetensors # type: ignore[import] # noqa: F401
import transformers # type: ignore[import] # noqa: F401
from safetensors import torch as safetensors_torch # noqa: F401
return True

View File

@ -13,6 +13,7 @@ import torch.fx
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.diagnostics import infra
if TYPE_CHECKING:
import torch.onnx
from torch.export.exported_program import ExportedProgram

View File

@ -15,11 +15,13 @@ from typing import (
)
import numpy
import onnx
import torch
from torch._subclasses import fake_tensor
if TYPE_CHECKING:
import onnx.defs.OpSchema.AttrType # type: ignore[import] # noqa: TCH004

View File

@ -13,9 +13,9 @@ from typing import (
import torch
import torch.export as torch_export
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
import inspect

View File

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs
"""Utilities for manipulating the torch.Graph object and the torchscript."""
from __future__ import annotations
# TODO(justinchuby): Move more of the symbolic helper functions here and expose
# them to the user.
from __future__ import annotations
import dataclasses
import re
import typing

View File

@ -4,7 +4,6 @@ import dataclasses
import importlib
import logging
import os
from typing import (
Any,
Callable,
@ -33,6 +32,7 @@ from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree
if TYPE_CHECKING:
import onnx
@ -929,7 +929,7 @@ class OrtBackend:
try:
from onnxscript import optimizer # type: ignore[import]
from onnxscript.rewriter import ( # type: ignore[import]
onnxruntime as ort_rewriter, # type: ignore[import]
onnxruntime as ort_rewriter,
)
onnx_model = optimizer.optimize(onnx_model)
@ -1112,7 +1112,6 @@ class OrtBackend:
the ``compile`` method is invoked directly."""
if self._options.use_aot_autograd:
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
return aot_autograd(

View File

@ -10,6 +10,7 @@ import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import errors
if typing.TYPE_CHECKING:
# Hack to help mypy to recognize torch._C.Value
from torch import _C # noqa: F401

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
from torch.onnx import _constants
from torch.onnx._internal import diagnostics
if TYPE_CHECKING:
from torch import _C

View File

@ -19,6 +19,7 @@ from torch.onnx import _constants, _type_utils, errors, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils
if typing.TYPE_CHECKING:
from torch.types import Number

View File

@ -23,6 +23,7 @@ from torch.onnx import (
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

View File

@ -21,6 +21,7 @@ from torch.onnx import (
)
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

View File

@ -25,6 +25,7 @@ from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
__all__ = [
"hardswish",
"tril",

View File

@ -33,6 +33,7 @@ from torch import _C
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)

View File

@ -36,6 +36,7 @@ from torch.nn.functional import (
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)

View File

@ -25,6 +25,7 @@ from torch import _C
from torch.onnx import _type_utils, errors, symbolic_helper
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

View File

@ -27,6 +27,7 @@ from torch import _C
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

View File

@ -26,6 +26,7 @@ Size
from typing import List
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

View File

@ -24,11 +24,11 @@ New operators:
import functools
import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

View File

@ -39,6 +39,7 @@ from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
block_listed_operators = (

View File

@ -27,6 +27,7 @@ from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_h
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration
if TYPE_CHECKING:
from torch.types import Number

View File

@ -30,6 +30,7 @@ from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils
from torch.types import Number
_ORT_PROVIDERS = ("CPUExecutionProvider",)
_NumericType = Union[Number, torch.Tensor, np.ndarray]

View File

@ -23,6 +23,7 @@ from torch.optim.rprop import Rprop
from torch.optim.sgd import SGD
from torch.optim.sparse_adam import SparseAdam
Adafactor.__module__ = "torch.optim"

View File

@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_disable_dynamo_if_unsupported,
_get_scalar_dtype,
@ -12,6 +13,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adafactor", "adafactor"]

View File

@ -20,6 +20,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adadelta", "adadelta"]

View File

@ -4,6 +4,7 @@ from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_default_to_fused_or_foreach,
_differentiable_doc,
@ -17,6 +18,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adagrad", "adagrad"]

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
@ -24,6 +25,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adam", "adam"]

View File

@ -21,6 +21,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Adamax", "adamax"]

View File

@ -5,6 +5,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
@ -24,6 +25,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["AdamW", "adamw"]

View File

@ -21,6 +21,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["ASGD", "asgd"]

View File

@ -3,8 +3,10 @@ from typing import Optional, Union
import torch
from torch import Tensor
from .optimizer import Optimizer, ParamsT
__all__ = ["LBFGS"]

View File

@ -26,6 +26,7 @@ from torch import inf, Tensor
from .optimizer import Optimizer
__all__ = [
"LambdaLR",
"MultiplicativeLR",

View File

@ -5,6 +5,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
@ -22,6 +23,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["NAdam", "nadam"]

View File

@ -35,6 +35,7 @@ from torch.utils._foreach_utils import (
)
from torch.utils.hooks import RemovableHandle
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]

View File

@ -22,6 +22,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["RAdam", "radam"]

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
@ -20,6 +21,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["RMSprop", "rmsprop"]

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
@ -20,6 +21,7 @@ from .optimizer import (
ParamsT,
)
__all__ = ["Rprop", "rprop"]

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_default_to_fused_or_foreach,
_differentiable_doc,
@ -16,6 +17,7 @@ from .optimizer import (
Optimizer,
)
__all__ = ["SGD", "sgd"]

View File

@ -3,9 +3,11 @@ from typing import List, Tuple, Union
import torch
from torch import Tensor
from . import _functional as F
from .optimizer import _maximize_doc, Optimizer, ParamsT
__all__ = ["SparseAdam"]

View File

@ -11,8 +11,10 @@ from torch import Tensor
from torch.nn import Module
from torch.optim.lr_scheduler import _format_param, LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
from .optimizer import Optimizer
__all__ = [
"AveragedModel",
"update_bn",
@ -25,6 +27,7 @@ __all__ = [
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]

View File

@ -6,6 +6,7 @@ from typing import cast
import torch
from torch.types import Storage
__serialization_id_record_name__ = ".data/serialization_id"

View File

@ -2,6 +2,7 @@
import _warnings
import os.path
# note: implementations
# copied from cpython's import code

View File

@ -4,6 +4,7 @@ See mangling.md for details.
"""
import re
_mangle_index = 0

View File

@ -1,6 +1,7 @@
from typing import Dict, List
from ..package_exporter import PackagingError
from torch.package.package_exporter import PackagingError
__all__ = ["find_first_use_of_broken_modules"]

View File

@ -2,6 +2,7 @@
import sys
from typing import Any, Callable, Iterable, List, Tuple
__all__ = ["trace_dependencies"]

View File

@ -3,6 +3,7 @@ from typing import Dict, List
from .glob_group import GlobGroup, GlobPattern
__all__ = ["Directory"]

View File

@ -2,6 +2,7 @@
import re
from typing import Iterable, Union
GlobPattern = Union[str, Iterable[str]]

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import importlib
from abc import ABC, abstractmethod
from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined]
from pickle import ( # type: ignore[attr-defined]
_getattribute,
_Pickler,
whichmodule as _pickle_whichmodule,
@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
from ._mangling import demangle, get_mangle_prefix, is_mangled
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]

View File

@ -39,6 +39,7 @@ from .find_file_dependencies import find_files_source_depends_on
from .glob_group import GlobGroup, GlobPattern
from .importer import Importer, OrderedImporter, sys_importer
__all__ = [
"PackagingErrorReason",
"EmptyMatchError",

View File

@ -38,6 +38,7 @@ from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .importer import Importer
if TYPE_CHECKING:
from .glob_group import GlobPattern

View File

@ -25,6 +25,7 @@ from .profiler import (
tensorboard_trace_handler,
)
__all__ = [
"profile",
"schedule",

View File

@ -32,6 +32,7 @@ from torch._C._profiler import (
from torch._utils import _element_size
from torch.profiler import _utils
KeyAndID = Tuple["Key", int]
TensorAndID = Tuple["TensorKey", int]

View File

@ -7,9 +7,9 @@ from dataclasses import dataclass
from typing import Dict, List, TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
if TYPE_CHECKING:
from torch.autograd import _KinetoEvent

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
try:
from torch._C import _itt
except ImportError:

View File

@ -1,16 +1,14 @@
# mypy: allow-untyped-defs
from .quantize import * # noqa: F403
from .observer import * # noqa: F403
from .qconfig import * # noqa: F403
from .fake_quantize import * # noqa: F403
from .fuse_modules import fuse_modules
from .stubs import * # noqa: F403
from .quant_type import * # noqa: F403
from .quantize_jit import * # noqa: F403
# from .quantize_fx import *
from .quantization_mappings import * # noqa: F403
from .fuser_method_mappings import * # noqa: F403
from .observer import * # noqa: F403
from .qconfig import * # noqa: F403
from .quant_type import * # noqa: F403
from .quantization_mappings import * # noqa: F403
from .quantize import * # noqa: F403
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403
def default_eval_fn(model, calib_data):

View File

@ -15,6 +15,7 @@ from torch.ao.quantization.fx.pattern_utils import (
QuantizeHandler,
)
# QuantizeHandler.__module__ = _NAMESPACE
_register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils"
get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils"

View File

@ -23,6 +23,7 @@ from torch.ao.quantization.fx.quantize_handler import (
StandaloneModuleQuantizeHandler,
)
QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"

View File

@ -15,6 +15,7 @@ from .semi_structured import (
to_sparse_semi_structured,
)
if TYPE_CHECKING:
from torch.types import _dtype as DType

View File

@ -3,6 +3,7 @@ import contextlib
import torch
__all__ = [
"fallback_dispatcher",
"semi_sparse_values",

View File

@ -8,8 +8,10 @@ from typing import Optional, Tuple
import torch
from torch.utils._triton import has_triton
from ._triton_ops_meta import get_meta
TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(
os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2)
)

View File

@ -20,6 +20,7 @@ from torch.sparse._semi_structured_ops import (
semi_sparse_view,
)
__all__ = [
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",

View File

@ -1,4 +1,5 @@
from torch._C import FileCheck as FileCheck
from . import _utils
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor

View File

@ -9,6 +9,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch
_INTEGRAL_TYPES = [
torch.uint8,
torch.int8,

View File

@ -8,7 +8,6 @@ from typing import Any, Dict
import torch
import torch.nn as nn
from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor

View File

@ -3,6 +3,7 @@ import logging
import os
import sys
# NOTE: [dynamo_test_failures.py]
#
# We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures`

View File

@ -34,9 +34,9 @@ from torch.testing._internal.common_utils import (
TrackedInputIter,
)
from torch.testing._internal.opinfo import utils
from torchgen.utils import dataclass_repr
# Reasonable testing sizes for dimensions
L = 20
M = 10

View File

@ -11,6 +11,7 @@ from torch.testing._internal.opinfo.definitions import (
special,
)
# Operator database
op_db: List[OpInfo] = [
*fft.op_db,

View File

@ -7,7 +7,6 @@ from typing import List
import numpy as np
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater
from torch.testing._internal.common_device_type import precisionOverride
@ -31,6 +30,7 @@ from torch.testing._internal.opinfo.refs import (
PythonRefInfo,
)
has_scipy_fft = False
if TEST_SCIPY:
try:

View File

@ -11,7 +11,6 @@ import numpy as np
from numpy import inf
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
_get_magma_version,

View File

@ -2,7 +2,6 @@
import unittest
from functools import partial
from itertools import product
from typing import Callable, List, Tuple
@ -18,6 +17,7 @@ from torch.testing._internal.opinfo.core import (
SampleInput,
)
if TEST_SCIPY:
import scipy.signal

View File

@ -7,6 +7,7 @@ from torch.testing._internal.opinfo.core import (
UnaryUfuncInfo,
)
# NOTE [Python References]
# Python References emulate existing PyTorch operations, but can ultimately
# be expressed in terms of "primitive" operations from torch._prims.

Some files were not shown because too many files have changed in this diff Show More