[BE][Easy] enable UFMT for torch/nn/ (#128865)

Part of #123062


Pull Request resolved: https://github.com/pytorch/pytorch/pull/128865
Approved by: https://github.com/ezyang
Author:     Xuehai Pan
Date:       2024-07-22 14:58:12 +08:00
Committer:  PyTorch MergeBot
Parent:     8ea4c72eb2
Commit:     b5c006acac

53 changed files with 517 additions and 459 deletions
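
A note for reviewers skimming the diffs below: UFMT is usort (import sorting) plus black (formatting), so every hunk in this PR only merges/reorders imports and normalizes string quotes; no runtime behavior changes. The overall shape of the rewrite is sketched here with an illustrative fragment (names borrowed from the intrinsic shims, not a verbatim copy of any one file):

# Before UFMT: one import per symbol, single-quoted __all__ entries.
from torch.ao.nn.intrinsic import ConvBn1d
from torch.ao.nn.intrinsic import LinearReLU
__all__ = ['ConvBn1d', 'LinearReLU']

# After UFMT: a single parenthesized import block, sorted case-insensitively,
# with double-quoted __all__ entries.
from torch.ao.nn.intrinsic import (
    ConvBn1d,
    LinearReLU,
)
__all__ = [
    "ConvBn1d",
    "LinearReLU",
]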

.lintrunner.toml
@@ -1476,59 +1476,6 @@ exclude_patterns = [
     'torch/linalg/__init__.py',
     'torch/monitor/__init__.py',
     'torch/nested/__init__.py',
-    'torch/nn/intrinsic/__init__.py',
-    'torch/nn/intrinsic/modules/__init__.py',
-    'torch/nn/intrinsic/modules/fused.py',
-    'torch/nn/intrinsic/qat/__init__.py',
-    'torch/nn/intrinsic/qat/modules/__init__.py',
-    'torch/nn/intrinsic/qat/modules/conv_fused.py',
-    'torch/nn/intrinsic/qat/modules/linear_fused.py',
-    'torch/nn/intrinsic/qat/modules/linear_relu.py',
-    'torch/nn/intrinsic/quantized/__init__.py',
-    'torch/nn/intrinsic/quantized/dynamic/__init__.py',
-    'torch/nn/intrinsic/quantized/dynamic/modules/__init__.py',
-    'torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py',
-    'torch/nn/intrinsic/quantized/modules/__init__.py',
-    'torch/nn/intrinsic/quantized/modules/bn_relu.py',
-    'torch/nn/intrinsic/quantized/modules/conv_relu.py',
-    'torch/nn/intrinsic/quantized/modules/linear_relu.py',
-    'torch/nn/qat/__init__.py',
-    'torch/nn/qat/dynamic/__init__.py',
-    'torch/nn/qat/dynamic/modules/__init__.py',
-    'torch/nn/qat/dynamic/modules/linear.py',
-    'torch/nn/qat/modules/__init__.py',
-    'torch/nn/qat/modules/conv.py',
-    'torch/nn/qat/modules/embedding_ops.py',
-    'torch/nn/qat/modules/linear.py',
-    'torch/nn/quantizable/__init__.py',
-    'torch/nn/quantizable/modules/__init__.py',
-    'torch/nn/quantizable/modules/activation.py',
-    'torch/nn/quantizable/modules/rnn.py',
-    'torch/nn/quantized/__init__.py',
-    'torch/nn/quantized/_reference/__init__.py',
-    'torch/nn/quantized/_reference/modules/__init__.py',
-    'torch/nn/quantized/_reference/modules/conv.py',
-    'torch/nn/quantized/_reference/modules/linear.py',
-    'torch/nn/quantized/_reference/modules/rnn.py',
-    'torch/nn/quantized/_reference/modules/sparse.py',
-    'torch/nn/quantized/_reference/modules/utils.py',
-    'torch/nn/quantized/dynamic/__init__.py',
-    'torch/nn/quantized/dynamic/modules/__init__.py',
-    'torch/nn/quantized/dynamic/modules/conv.py',
-    'torch/nn/quantized/dynamic/modules/linear.py',
-    'torch/nn/quantized/dynamic/modules/rnn.py',
-    'torch/nn/quantized/functional.py',
-    'torch/nn/quantized/modules/__init__.py',
-    'torch/nn/quantized/modules/activation.py',
-    'torch/nn/quantized/modules/batchnorm.py',
-    'torch/nn/quantized/modules/conv.py',
-    'torch/nn/quantized/modules/dropout.py',
-    'torch/nn/quantized/modules/embedding_ops.py',
-    'torch/nn/quantized/modules/functional_modules.py',
-    'torch/nn/quantized/modules/linear.py',
-    'torch/nn/quantized/modules/normalization.py',
-    'torch/nn/quantized/modules/rnn.py',
-    'torch/nn/quantized/modules/utils.py',
     'torch/signal/__init__.py',
     'torch/signal/windows/__init__.py',
     'torch/signal/windows/windows.py',

torch/nn/intrinsic/__init__.py
@@ -1,35 +1,36 @@
-from torch.ao.nn.intrinsic import ConvBn1d
-from torch.ao.nn.intrinsic import ConvBn2d
-from torch.ao.nn.intrinsic import ConvBn3d
-from torch.ao.nn.intrinsic import ConvBnReLU1d
-from torch.ao.nn.intrinsic import ConvBnReLU2d
-from torch.ao.nn.intrinsic import ConvBnReLU3d
-from torch.ao.nn.intrinsic import ConvReLU1d
-from torch.ao.nn.intrinsic import ConvReLU2d
-from torch.ao.nn.intrinsic import ConvReLU3d
-from torch.ao.nn.intrinsic import LinearReLU
-from torch.ao.nn.intrinsic import BNReLU2d
-from torch.ao.nn.intrinsic import BNReLU3d
-from torch.ao.nn.intrinsic import LinearBn1d
+from torch.ao.nn.intrinsic import (
+    BNReLU2d,
+    BNReLU3d,
+    ConvBn1d,
+    ConvBn2d,
+    ConvBn3d,
+    ConvBnReLU1d,
+    ConvBnReLU2d,
+    ConvBnReLU3d,
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+    LinearBn1d,
+    LinearReLU,
+)
 from torch.ao.nn.intrinsic.modules.fused import _FusedModule  # noqa: F401
 # Include the subpackages in case user imports from it directly
-from . import modules  # noqa: F401
-from . import qat  # noqa: F401
-from . import quantized  # noqa: F401
+from torch.nn.intrinsic import modules, qat, quantized  # noqa: F401
 __all__ = [
-    'ConvBn1d',
-    'ConvBn2d',
-    'ConvBn3d',
-    'ConvBnReLU1d',
-    'ConvBnReLU2d',
-    'ConvBnReLU3d',
-    'ConvReLU1d',
-    'ConvReLU2d',
-    'ConvReLU3d',
-    'LinearReLU',
-    'BNReLU2d',
-    'BNReLU3d',
-    'LinearBn1d',
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBn3d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBnReLU3d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "LinearReLU",
+    "BNReLU2d",
+    "BNReLU3d",
+    "LinearBn1d",
 ]

torch/nn/intrinsic/modules/__init__.py
@@ -1,31 +1,33 @@
-from .fused import _FusedModule  # noqa: F401
-from .fused import BNReLU2d
-from .fused import BNReLU3d
-from .fused import ConvBn1d
-from .fused import ConvBn2d
-from .fused import ConvBn3d
-from .fused import ConvBnReLU1d
-from .fused import ConvBnReLU2d
-from .fused import ConvBnReLU3d
-from .fused import ConvReLU1d
-from .fused import ConvReLU2d
-from .fused import ConvReLU3d
-from .fused import LinearBn1d
-from .fused import LinearReLU
+from torch.nn.intrinsic.modules.fused import (
+    _FusedModule,
+    BNReLU2d,
+    BNReLU3d,
+    ConvBn1d,
+    ConvBn2d,
+    ConvBn3d,
+    ConvBnReLU1d,
+    ConvBnReLU2d,
+    ConvBnReLU3d,
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+    LinearBn1d,
+    LinearReLU,
+)
 __all__ = [
-    'BNReLU2d',
-    'BNReLU3d',
-    'ConvBn1d',
-    'ConvBn2d',
-    'ConvBn3d',
-    'ConvBnReLU1d',
-    'ConvBnReLU2d',
-    'ConvBnReLU3d',
-    'ConvReLU1d',
-    'ConvReLU2d',
-    'ConvReLU3d',
-    'LinearBn1d',
-    'LinearReLU',
+    "BNReLU2d",
+    "BNReLU3d",
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBn3d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBnReLU3d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "LinearBn1d",
+    "LinearReLU",
 ]

torch/nn/intrinsic/modules/fused.py
@@ -1,30 +1,33 @@
-from torch.ao.nn.intrinsic import BNReLU2d
-from torch.ao.nn.intrinsic import BNReLU3d
-from torch.ao.nn.intrinsic import ConvBn1d
-from torch.ao.nn.intrinsic import ConvBn2d
-from torch.ao.nn.intrinsic import ConvBn3d
-from torch.ao.nn.intrinsic import ConvBnReLU1d
-from torch.ao.nn.intrinsic import ConvBnReLU2d
-from torch.ao.nn.intrinsic import ConvBnReLU3d
-from torch.ao.nn.intrinsic import ConvReLU1d
-from torch.ao.nn.intrinsic import ConvReLU2d
-from torch.ao.nn.intrinsic import ConvReLU3d
-from torch.ao.nn.intrinsic import LinearBn1d
-from torch.ao.nn.intrinsic import LinearReLU
+from torch.ao.nn.intrinsic import (
+    BNReLU2d,
+    BNReLU3d,
+    ConvBn1d,
+    ConvBn2d,
+    ConvBn3d,
+    ConvBnReLU1d,
+    ConvBnReLU2d,
+    ConvBnReLU3d,
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+    LinearBn1d,
+    LinearReLU,
+)
 from torch.ao.nn.intrinsic.modules.fused import _FusedModule  # noqa: F401
 __all__ = [
-    'BNReLU2d',
-    'BNReLU3d',
-    'ConvBn1d',
-    'ConvBn2d',
-    'ConvBn3d',
-    'ConvBnReLU1d',
-    'ConvBnReLU2d',
-    'ConvBnReLU3d',
-    'ConvReLU1d',
-    'ConvReLU2d',
-    'ConvReLU3d',
-    'LinearBn1d',
-    'LinearReLU',
+    "BNReLU2d",
+    "BNReLU3d",
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBn3d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBnReLU3d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "LinearBn1d",
+    "LinearReLU",
 ]

torch/nn/intrinsic/qat/__init__.py
@@ -1 +1 @@
-from .modules import *  # noqa: F403
+from torch.nn.intrinsic.qat.modules import *  # noqa: F403

torch/nn/intrinsic/qat/modules/__init__.py
@@ -1,6 +1,4 @@
-from .linear_relu import LinearReLU
-from .linear_fused import LinearBn1d
-from .conv_fused import (
+from torch.nn.intrinsic.qat.modules.conv_fused import (
     ConvBn1d,
     ConvBn2d,
     ConvBn3d,
@@ -10,9 +8,12 @@ from .conv_fused import (
     ConvReLU1d,
     ConvReLU2d,
     ConvReLU3d,
-    update_bn_stats,
     freeze_bn_stats,
+    update_bn_stats,
 )
+from torch.nn.intrinsic.qat.modules.linear_fused import LinearBn1d
+from torch.nn.intrinsic.qat.modules.linear_relu import LinearReLU
 __all__ = [
     "LinearReLU",

torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -8,30 +8,33 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
 while adding an import statement here.
 """
+from torch.ao.nn.intrinsic.qat import (
+    ConvBn1d,
+    ConvBn2d,
+    ConvBn3d,
+    ConvBnReLU1d,
+    ConvBnReLU2d,
+    ConvBnReLU3d,
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+    freeze_bn_stats,
+    update_bn_stats,
+)
 __all__ = [
     # Modules
-    'ConvBn1d',
-    'ConvBnReLU1d',
-    'ConvReLU1d',
-    'ConvBn2d',
-    'ConvBnReLU2d',
-    'ConvReLU2d',
-    'ConvBn3d',
-    'ConvBnReLU3d',
-    'ConvReLU3d',
+    "ConvBn1d",
+    "ConvBnReLU1d",
+    "ConvReLU1d",
+    "ConvBn2d",
+    "ConvBnReLU2d",
+    "ConvReLU2d",
+    "ConvBn3d",
+    "ConvBnReLU3d",
+    "ConvReLU3d",
     # Utilities
-    'freeze_bn_stats',
-    'update_bn_stats',
+    "freeze_bn_stats",
+    "update_bn_stats",
 ]
-from torch.ao.nn.intrinsic.qat import ConvBn1d
-from torch.ao.nn.intrinsic.qat import ConvBnReLU1d
-from torch.ao.nn.intrinsic.qat import ConvReLU1d
-from torch.ao.nn.intrinsic.qat import ConvBn2d
-from torch.ao.nn.intrinsic.qat import ConvBnReLU2d
-from torch.ao.nn.intrinsic.qat import ConvReLU2d
-from torch.ao.nn.intrinsic.qat import ConvBn3d
-from torch.ao.nn.intrinsic.qat import ConvBnReLU3d
-from torch.ao.nn.intrinsic.qat import ConvReLU3d
-from torch.ao.nn.intrinsic.qat import freeze_bn_stats
-from torch.ao.nn.intrinsic.qat import update_bn_stats

torch/nn/intrinsic/qat/modules/linear_fused.py
@@ -8,8 +8,9 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
 while adding an import statement here.
 """
-__all__ = [
-    'LinearBn1d',
-]
 from torch.ao.nn.intrinsic.qat import LinearBn1d
+__all__ = [
+    "LinearBn1d",
+]

torch/nn/intrinsic/qat/modules/linear_relu.py
@@ -8,8 +8,9 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
 while adding an import statement here.
 """
-__all__ = [
-    'LinearReLU',
-]
 from torch.ao.nn.intrinsic.qat import LinearReLU
+__all__ = [
+    "LinearReLU",
+]

torch/nn/intrinsic/quantized/__init__.py
@@ -1,13 +1,14 @@
-from .modules import *  # noqa: F403
 # to ensure customers can use the module below
 # without importing it directly
-import torch.nn.intrinsic.quantized.dynamic
+from torch.nn.intrinsic.quantized import dynamic, modules  # noqa: F401
+from torch.nn.intrinsic.quantized.modules import *  # noqa: F403
 __all__ = [
-    'BNReLU2d',
-    'BNReLU3d',
-    'ConvReLU1d',
-    'ConvReLU2d',
-    'ConvReLU3d',
-    'LinearReLU',
+    "BNReLU2d",
+    "BNReLU3d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "LinearReLU",
 ]

torch/nn/intrinsic/quantized/dynamic/__init__.py
@@ -1 +1 @@
-from .modules import *  # noqa: F403
+from torch.nn.intrinsic.quantized.dynamic.modules import *  # noqa: F403

torch/nn/intrinsic/quantized/dynamic/modules/__init__.py
@@ -1,5 +1,6 @@
-from .linear_relu import LinearReLU
+from torch.nn.intrinsic.quantized.dynamic.modules.linear_relu import LinearReLU
 __all__ = [
-    'LinearReLU',
+    "LinearReLU",
 ]

torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
@@ -1,5 +1,6 @@
 from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU
 __all__ = [
-    'LinearReLU',
+    "LinearReLU",
 ]

torch/nn/intrinsic/quantized/modules/__init__.py
@@ -1,12 +1,17 @@
-from .linear_relu import LinearReLU
-from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
-from .bn_relu import BNReLU2d, BNReLU3d
+from torch.nn.intrinsic.quantized.modules.bn_relu import BNReLU2d, BNReLU3d
+from torch.nn.intrinsic.quantized.modules.conv_relu import (
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+)
+from torch.nn.intrinsic.quantized.modules.linear_relu import LinearReLU
 __all__ = [
-    'LinearReLU',
-    'ConvReLU1d',
-    'ConvReLU2d',
-    'ConvReLU3d',
-    'BNReLU2d',
-    'BNReLU3d',
+    "LinearReLU",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "BNReLU2d",
+    "BNReLU3d",
 ]

torch/nn/intrinsic/quantized/modules/bn_relu.py
@@ -1,7 +1,7 @@
-from torch.ao.nn.intrinsic.quantized import BNReLU2d
-from torch.ao.nn.intrinsic.quantized import BNReLU3d
+from torch.ao.nn.intrinsic.quantized import BNReLU2d, BNReLU3d
 __all__ = [
-    'BNReLU2d',
-    'BNReLU3d',
+    "BNReLU2d",
+    "BNReLU3d",
 ]

torch/nn/intrinsic/quantized/modules/conv_relu.py
@@ -1,9 +1,8 @@
-from torch.ao.nn.intrinsic.quantized import ConvReLU1d
-from torch.ao.nn.intrinsic.quantized import ConvReLU2d
-from torch.ao.nn.intrinsic.quantized import ConvReLU3d
+from torch.ao.nn.intrinsic.quantized import ConvReLU1d, ConvReLU2d, ConvReLU3d
 __all__ = [
-    'ConvReLU1d',
-    'ConvReLU2d',
-    'ConvReLU3d',
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
 ]

torch/nn/intrinsic/quantized/modules/linear_relu.py
@@ -1,5 +1,6 @@
 from torch.ao.nn.intrinsic.quantized import LinearReLU
 __all__ = [
-    'LinearReLU',
+    "LinearReLU",
 ]

torch/nn/parallel/__init__.py
@@ -1,11 +1,11 @@
 # mypy: allow-untyped-defs
 from typing_extensions import deprecated
-from .data_parallel import data_parallel, DataParallel
-from .distributed import DistributedDataParallel
-from .parallel_apply import parallel_apply
-from .replicate import replicate
-from .scatter_gather import gather, scatter
+from torch.nn.parallel.data_parallel import data_parallel, DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.nn.parallel.parallel_apply import parallel_apply
+from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.scatter_gather import gather, scatter
 __all__ = [
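
Since the torch/nn/parallel hunks only swap relative imports for absolute ones (and re-sort them), the package's public surface is unchanged. As an illustrative sanity check (ordinary usage, not code taken from the diff), imports such as the following resolve to the same objects before and after this commit:

from torch.nn.parallel import DataParallel, DistributedDataParallel, parallel_apply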

torch/nn/parallel/_functions.py
@@ -4,8 +4,7 @@ from typing import List, Optional
 import torch
 from torch._utils import _get_device_index
 from torch.autograd import Function
-from . import comm
+from torch.nn.parallel import comm
 class Broadcast(Function):

torch/nn/parallel/data_parallel.py
@@ -11,11 +11,10 @@ from torch._utils import (
     _get_device_index,
     _get_devices_properties,
 )
-from ..modules import Module
-from .parallel_apply import parallel_apply
-from .replicate import replicate
-from .scatter_gather import gather, scatter_kwargs
+from torch.nn.modules import Module
+from torch.nn.parallel.parallel_apply import parallel_apply
+from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
 __all__ = ["DataParallel", "data_parallel"]

torch/nn/parallel/distributed.py
@@ -19,11 +19,10 @@ import torch.distributed as dist
 from torch._utils import _get_device_index
 from torch.autograd import Function, Variable
 from torch.distributed.algorithms.join import Join, Joinable, JoinHook
+from torch.nn.modules import Module
+from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from ..modules import Module
-from .scatter_gather import gather, scatter_kwargs
 RPC_AVAILABLE = False
 if dist.is_available():
@@ -47,6 +46,7 @@ if dist.rpc.is_available():
 if TYPE_CHECKING:
     from torch.utils.hooks import RemovableHandle
 __all__ = ["DistributedDataParallel"]
 logger = logging.getLogger(__name__)

torch/nn/parallel/parallel_apply.py
@@ -4,8 +4,7 @@ from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._utils import ExceptionWrapper
 from torch.cuda._utils import _get_device_index
-from ..modules import Module
+from torch.nn.modules import Module
 __all__ = ["get_a_var", "parallel_apply"]

torch/nn/parallel/replicate.py
@@ -14,9 +14,8 @@ from typing import (
 import torch
 from torch._utils import _get_device_index
-from ..modules import Module
-from . import comm
+from torch.nn.modules import Module
+from torch.nn.parallel import comm
 if TYPE_CHECKING:
@@ -94,7 +93,7 @@ def _broadcast_coalesced_reshape(
     devices: Sequence[Union[int, torch.device]],
     detach: bool = False,
 ) -> List[List[torch.Tensor]]:
-    from ._functions import Broadcast
+    from torch.nn.parallel._functions import Broadcast
     if detach:
         return comm.broadcast_coalesced(tensors, devices)

torch/nn/parallel/scatter_gather.py
@@ -3,8 +3,7 @@ from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar
 from typing_extensions import deprecated
 import torch
-from ._functions import Gather, Scatter
+from torch.nn.parallel._functions import Gather, Scatter
 __all__ = ["scatter", "scatter_kwargs", "gather"]

torch/nn/qat/__init__.py
@@ -4,9 +4,9 @@ r"""QAT Dynamic Modules.
 This package is in the process of being deprecated.
 Please, use `torch.ao.nn.qat.dynamic` instead.
 """
-from . import dynamic  # noqa: F403
-from . import modules  # noqa: F403
-from .modules import *  # noqa: F403
+from torch.nn.qat import dynamic, modules  # noqa: F403
+from torch.nn.qat.modules import *  # noqa: F403
 __all__ = [
     "Linear",

torch/nn/qat/dynamic/__init__.py
@@ -4,4 +4,4 @@ r"""QAT Dynamic Modules.
 This package is in the process of being deprecated.
 Please, use `torch.ao.nn.qat.dynamic` instead.
 """
-from .modules import *  # noqa: F403
+from torch.nn.qat.dynamic.modules import *  # noqa: F403

torch/nn/qat/dynamic/modules/__init__.py
@@ -1,3 +1,4 @@
-from .linear import Linear
+from torch.nn.qat.dynamic.modules.linear import Linear
 __all__ = ["Linear"]

torch/nn/qat/modules/__init__.py
@@ -4,15 +4,11 @@ r"""QAT Modules.
 This package is in the process of being deprecated.
 Please, use `torch.ao.nn.qat.modules` instead.
 """
+from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
+from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
 from torch.ao.nn.qat.modules.linear import Linear
-from torch.ao.nn.qat.modules.conv import Conv1d
-from torch.ao.nn.qat.modules.conv import Conv2d
-from torch.ao.nn.qat.modules.conv import Conv3d
-from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag, Embedding
-from . import conv
-from . import embedding_ops
-from . import linear
+from torch.nn.qat.modules import conv, embedding_ops, linear
 __all__ = [
     "Linear",

torch/nn/qat/modules/conv.py
@@ -7,6 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate file under the `torch/ao/nn/qat/modules`,
 while adding an import statement here.
 """
-from torch.ao.nn.qat.modules.conv import Conv1d
-from torch.ao.nn.qat.modules.conv import Conv2d
-from torch.ao.nn.qat.modules.conv import Conv3d
+from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d

torch/nn/qat/modules/embedding_ops.py
@@ -8,7 +8,7 @@ appropriate file under the `torch/ao/nn/qat/modules`,
 while adding an import statement here.
 """
-__all__ = ['Embedding', 'EmbeddingBag']
-from torch.ao.nn.qat.modules.embedding_ops import Embedding
-from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag
+from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
+__all__ = ["Embedding", "EmbeddingBag"]

torch/nn/quantizable/__init__.py
@@ -1 +1 @@
-from .modules import *  # noqa: F403
+from torch.nn.quantizable.modules import *  # noqa: F403

torch/nn/quantizable/modules/__init__.py
@@ -1,9 +1,9 @@
 from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
-from torch.ao.nn.quantizable.modules.rnn import LSTM
-from torch.ao.nn.quantizable.modules.rnn import LSTMCell
+from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell
 __all__ = [
-    'LSTM',
-    'LSTMCell',
-    'MultiheadAttention',
+    "LSTM",
+    "LSTMCell",
+    "MultiheadAttention",
 ]

torch/nn/quantizable/modules/rnn.py
@@ -7,5 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate file under the `torch/ao/nn/quantizable/modules`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantizable.modules.rnn import LSTM
-from torch.ao.nn.quantizable.modules.rnn import LSTMCell
+from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell

torch/nn/quantized/__init__.py
@@ -1,40 +1,39 @@
-from . import dynamic  # noqa: F403
-from . import functional  # noqa: F403
-from . import modules  # noqa: F403
-from .modules import *  # noqa: F403
-from .modules import MaxPool2d
+from torch.nn.quantized import dynamic, functional, modules  # noqa: F403
+from torch.nn.quantized.modules import *  # noqa: F403
+from torch.nn.quantized.modules import MaxPool2d
 __all__ = [
-    'BatchNorm2d',
-    'BatchNorm3d',
-    'Conv1d',
-    'Conv2d',
-    'Conv3d',
-    'ConvTranspose1d',
-    'ConvTranspose2d',
-    'ConvTranspose3d',
-    'DeQuantize',
-    'Dropout',
-    'ELU',
-    'Embedding',
-    'EmbeddingBag',
-    'GroupNorm',
-    'Hardswish',
-    'InstanceNorm1d',
-    'InstanceNorm2d',
-    'InstanceNorm3d',
-    'LayerNorm',
-    'LeakyReLU',
-    'Linear',
-    'LSTM',
-    'MultiheadAttention',
-    'PReLU',
-    'Quantize',
-    'ReLU6',
-    'Sigmoid',
-    'Softmax',
+    "BatchNorm2d",
+    "BatchNorm3d",
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
+    "DeQuantize",
+    "Dropout",
+    "ELU",
+    "Embedding",
+    "EmbeddingBag",
+    "GroupNorm",
+    "Hardswish",
+    "InstanceNorm1d",
+    "InstanceNorm2d",
+    "InstanceNorm3d",
+    "LayerNorm",
+    "LeakyReLU",
+    "Linear",
+    "LSTM",
+    "MultiheadAttention",
+    "PReLU",
+    "Quantize",
+    "ReLU6",
+    "Sigmoid",
+    "Softmax",
     # Wrapper modules
-    'FloatFunctional',
-    'FXFloatFunctional',
-    'QFunctional',
+    "FloatFunctional",
+    "FXFloatFunctional",
+    "QFunctional",
 ]

torch/nn/quantized/_reference/__init__.py
@@ -1 +1 @@
-from .modules import *  # noqa: F403
+from torch.nn.quantized._reference.modules import *  # noqa: F403

torch/nn/quantized/_reference/modules/__init__.py
@@ -9,23 +9,31 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
 while adding an import statement here.
 """
+from torch.ao.nn.quantized.reference.modules.conv import (
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+)
 from torch.ao.nn.quantized.reference.modules.linear import Linear
-from torch.ao.nn.quantized.reference.modules.conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
-from torch.ao.nn.quantized.reference.modules.rnn import RNNCell, LSTMCell, GRUCell, LSTM
+from torch.ao.nn.quantized.reference.modules.rnn import GRUCell, LSTM, LSTMCell, RNNCell
 from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
 __all__ = [
-    'Linear',
-    'Conv1d',
-    'Conv2d',
-    'Conv3d',
-    'ConvTranspose1d',
-    'ConvTranspose2d',
-    'ConvTranspose3d',
-    'RNNCell',
-    'LSTMCell',
-    'GRUCell',
-    'LSTM',
-    'Embedding',
-    'EmbeddingBag',
+    "Linear",
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
+    "RNNCell",
+    "LSTMCell",
+    "GRUCell",
+    "LSTM",
+    "Embedding",
+    "EmbeddingBag",
 ]

torch/nn/quantized/_reference/modules/conv.py
@@ -9,11 +9,13 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantized.reference.modules.conv import _ConvNd
-from torch.ao.nn.quantized.reference.modules.conv import Conv1d
-from torch.ao.nn.quantized.reference.modules.conv import Conv2d
-from torch.ao.nn.quantized.reference.modules.conv import Conv3d
-from torch.ao.nn.quantized.reference.modules.conv import _ConvTransposeNd
-from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose1d
-from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose2d
-from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose3d
+from torch.ao.nn.quantized.reference.modules.conv import (
+    _ConvNd,
+    _ConvTransposeNd,
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+)

torch/nn/quantized/_reference/modules/rnn.py
@@ -9,9 +9,11 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantized.reference.modules.rnn import RNNCellBase
-from torch.ao.nn.quantized.reference.modules.rnn import RNNCell
-from torch.ao.nn.quantized.reference.modules.rnn import LSTMCell
-from torch.ao.nn.quantized.reference.modules.rnn import GRUCell
-from torch.ao.nn.quantized.reference.modules.rnn import RNNBase
-from torch.ao.nn.quantized.reference.modules.rnn import LSTM
+from torch.ao.nn.quantized.reference.modules.rnn import (
+    GRUCell,
+    LSTM,
+    LSTMCell,
+    RNNBase,
+    RNNCell,
+    RNNCellBase,
+)

torch/nn/quantized/_reference/modules/sparse.py
@@ -9,5 +9,4 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantized.reference.modules.sparse import Embedding
-from torch.ao.nn.quantized.reference.modules.sparse import EmbeddingBag
+from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag

torch/nn/quantized/_reference/modules/utils.py
@@ -8,8 +8,11 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate file under the `torch/ao/nn/quantized/reference`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantized.reference.modules.utils import _quantize_weight
-from torch.ao.nn.quantized.reference.modules.utils import _quantize_and_dequantize_weight
-from torch.ao.nn.quantized.reference.modules.utils import _save_weight_qparams
-from torch.ao.nn.quantized.reference.modules.utils import _get_weight_qparam_keys
-from torch.ao.nn.quantized.reference.modules.utils import ReferenceQuantizedModule
+from torch.ao.nn.quantized.reference.modules.utils import (
+    _get_weight_qparam_keys,
+    _quantize_and_dequantize_weight,
+    _quantize_weight,
+    _save_weight_qparams,
+    ReferenceQuantizedModule,
+)

torch/nn/quantized/dynamic/__init__.py
@@ -8,25 +8,36 @@ appropriate file under the `torch/ao/nn/quantized/dynamic`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantized.dynamic.modules import conv
-from torch.ao.nn.quantized.dynamic.modules import linear
-from torch.ao.nn.quantized.dynamic.modules import rnn
-from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
+from torch.ao.nn.quantized.dynamic.modules import conv, linear, rnn
+from torch.ao.nn.quantized.dynamic.modules.conv import (
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+)
 from torch.ao.nn.quantized.dynamic.modules.linear import Linear
-from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM, GRU, LSTMCell, RNNCell, GRUCell
+from torch.ao.nn.quantized.dynamic.modules.rnn import (
+    GRU,
+    GRUCell,
+    LSTM,
+    LSTMCell,
+    RNNCell,
+)
 __all__ = [
-    'Linear',
-    'LSTM',
-    'GRU',
-    'LSTMCell',
-    'RNNCell',
-    'GRUCell',
-    'Conv1d',
-    'Conv2d',
-    'Conv3d',
-    'ConvTranspose1d',
-    'ConvTranspose2d',
-    'ConvTranspose3d',
+    "Linear",
+    "LSTM",
+    "GRU",
+    "LSTMCell",
+    "RNNCell",
+    "GRUCell",
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
 ]

torch/nn/quantized/dynamic/modules/conv.py
@@ -8,11 +8,21 @@ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
 while adding an import statement here.
 """
-__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
-from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d
-from torch.ao.nn.quantized.dynamic.modules.conv import Conv2d
-from torch.ao.nn.quantized.dynamic.modules.conv import Conv3d
-from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose1d
-from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose2d
-from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose3d
+from torch.ao.nn.quantized.dynamic.modules.conv import (
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+)
+__all__ = [
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
+]

torch/nn/quantized/dynamic/modules/rnn.py
@@ -8,15 +8,27 @@ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
 while adding an import statement here.
 """
-__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell',
-           'GRUCell']
-from torch.ao.nn.quantized.dynamic.modules.rnn import pack_weight_bias
-from torch.ao.nn.quantized.dynamic.modules.rnn import PackedParameter
-from torch.ao.nn.quantized.dynamic.modules.rnn import RNNBase
-from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM
-from torch.ao.nn.quantized.dynamic.modules.rnn import GRU
-from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCellBase
-from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCell
-from torch.ao.nn.quantized.dynamic.modules.rnn import LSTMCell
-from torch.ao.nn.quantized.dynamic.modules.rnn import GRUCell
+from torch.ao.nn.quantized.dynamic.modules.rnn import (
+    GRU,
+    GRUCell,
+    LSTM,
+    LSTMCell,
+    pack_weight_bias,
+    PackedParameter,
+    RNNBase,
+    RNNCell,
+    RNNCellBase,
+)
+__all__ = [
+    "pack_weight_bias",
+    "PackedParameter",
+    "RNNBase",
+    "LSTM",
+    "GRU",
+    "RNNCellBase",
+    "RNNCell",
+    "LSTMCell",
+    "GRUCell",
+]

torch/nn/quantized/modules/__init__.py
@@ -5,66 +5,93 @@ Note::
     Please, use `torch.ao.nn.quantized` instead.
 """
-from torch.ao.nn.quantized.modules.activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU
-from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
-from torch.ao.nn.quantized.modules.conv import Conv1d, Conv2d, Conv3d
-from torch.ao.nn.quantized.modules.conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
-from torch.ao.nn.quantized.modules.dropout import Dropout
-from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
-from torch.ao.nn.quantized.modules.functional_modules import FloatFunctional, FXFloatFunctional, QFunctional
-from torch.ao.nn.quantized.modules.linear import Linear
-from torch.ao.nn.quantized.modules.normalization import LayerNorm, GroupNorm, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
-from torch.ao.nn.quantized.modules.rnn import LSTM
-from torch.ao.nn.quantized.modules import MaxPool2d
-from torch.ao.nn.quantized.modules import Quantize, DeQuantize
 # The following imports are needed in case the user decides
 # to import the files directly,
 # s.a. `from torch.nn.quantized.modules.conv import ...`.
 # No need to add them to the `__all__`.
-from torch.ao.nn.quantized.modules import activation
-from torch.ao.nn.quantized.modules import batchnorm
-from torch.ao.nn.quantized.modules import conv
-from torch.ao.nn.quantized.modules import dropout
-from torch.ao.nn.quantized.modules import embedding_ops
-from torch.ao.nn.quantized.modules import functional_modules
-from torch.ao.nn.quantized.modules import linear
-from torch.ao.nn.quantized.modules import normalization
-from torch.ao.nn.quantized.modules import rnn
-from torch.ao.nn.quantized.modules import utils
+from torch.ao.nn.quantized.modules import (
+    activation,
+    batchnorm,
+    conv,
+    DeQuantize,
+    dropout,
+    embedding_ops,
+    functional_modules,
+    linear,
+    MaxPool2d,
+    normalization,
+    Quantize,
+    rnn,
+    utils,
+)
+from torch.ao.nn.quantized.modules.activation import (
+    ELU,
+    Hardswish,
+    LeakyReLU,
+    MultiheadAttention,
+    PReLU,
+    ReLU6,
+    Sigmoid,
+    Softmax,
+)
+from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
+from torch.ao.nn.quantized.modules.conv import (
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+)
+from torch.ao.nn.quantized.modules.dropout import Dropout
+from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
+from torch.ao.nn.quantized.modules.functional_modules import (
+    FloatFunctional,
+    FXFloatFunctional,
+    QFunctional,
+)
+from torch.ao.nn.quantized.modules.linear import Linear
+from torch.ao.nn.quantized.modules.normalization import (
+    GroupNorm,
+    InstanceNorm1d,
+    InstanceNorm2d,
+    InstanceNorm3d,
+    LayerNorm,
+)
+from torch.ao.nn.quantized.modules.rnn import LSTM
 __all__ = [
-    'BatchNorm2d',
-    'BatchNorm3d',
-    'Conv1d',
-    'Conv2d',
-    'Conv3d',
-    'ConvTranspose1d',
-    'ConvTranspose2d',
-    'ConvTranspose3d',
-    'DeQuantize',
-    'ELU',
-    'Embedding',
-    'EmbeddingBag',
-    'GroupNorm',
-    'Hardswish',
-    'InstanceNorm1d',
-    'InstanceNorm2d',
-    'InstanceNorm3d',
-    'LayerNorm',
-    'LeakyReLU',
-    'Linear',
-    'LSTM',
-    'MultiheadAttention',
-    'Quantize',
-    'ReLU6',
-    'Sigmoid',
-    'Softmax',
-    'Dropout',
-    'PReLU',
+    "BatchNorm2d",
+    "BatchNorm3d",
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
+    "DeQuantize",
+    "ELU",
+    "Embedding",
+    "EmbeddingBag",
+    "GroupNorm",
+    "Hardswish",
+    "InstanceNorm1d",
+    "InstanceNorm2d",
+    "InstanceNorm3d",
+    "LayerNorm",
+    "LeakyReLU",
+    "Linear",
+    "LSTM",
+    "MultiheadAttention",
+    "Quantize",
+    "ReLU6",
+    "Sigmoid",
+    "Softmax",
+    "Dropout",
+    "PReLU",
     # Wrapper modules
-    'FloatFunctional',
-    'FXFloatFunctional',
-    'QFunctional',
+    "FloatFunctional",
+    "FXFloatFunctional",
+    "QFunctional",
 ]
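
The torch/nn/quantized/modules/__init__.py hunk above is representative of all the quantized shims: they merely re-export the torch.ao.nn implementations, so the legacy and the torch.ao import paths keep naming the same classes. An illustrative check (not part of the diff):

import torch.ao.nn.quantized as aonnq
import torch.nn.quantized as nnq

assert nnq.Linear is aonnq.Linear
assert nnq.Conv2d is aonnq.Conv2d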

torch/nn/quantized/modules/activation.py
@@ -8,11 +8,13 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantized.modules.activation import ELU
-from torch.ao.nn.quantized.modules.activation import Hardswish
-from torch.ao.nn.quantized.modules.activation import LeakyReLU
-from torch.ao.nn.quantized.modules.activation import MultiheadAttention
-from torch.ao.nn.quantized.modules.activation import PReLU
-from torch.ao.nn.quantized.modules.activation import ReLU6
-from torch.ao.nn.quantized.modules.activation import Sigmoid
-from torch.ao.nn.quantized.modules.activation import Softmax
+from torch.ao.nn.quantized.modules.activation import (
+    ELU,
+    Hardswish,
+    LeakyReLU,
+    MultiheadAttention,
+    PReLU,
+    ReLU6,
+    Sigmoid,
+    Softmax,
+)

torch/nn/quantized/modules/batchnorm.py
@@ -8,5 +8,4 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d
-from torch.ao.nn.quantized.modules.batchnorm import BatchNorm3d
+from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d

torch/nn/quantized/modules/conv.py
@@ -8,14 +8,22 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
-from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
-from torch.ao.nn.quantized.modules.conv import Conv1d
-from torch.ao.nn.quantized.modules.conv import Conv2d
-from torch.ao.nn.quantized.modules.conv import Conv3d
-from torch.ao.nn.quantized.modules.conv import ConvTranspose1d
-from torch.ao.nn.quantized.modules.conv import ConvTranspose2d
-from torch.ao.nn.quantized.modules.conv import ConvTranspose3d
+from torch.ao.nn.quantized.modules.conv import (
+    _reverse_repeat_padding,
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+)
+__all__ = [
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
+]

torch/nn/quantized/modules/dropout.py
@@ -8,6 +8,7 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-__all__ = ['Dropout']
 from torch.ao.nn.quantized.modules.dropout import Dropout
+__all__ = ["Dropout"]

torch/nn/quantized/modules/embedding_ops.py
@@ -8,8 +8,11 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag']
-from torch.ao.nn.quantized.modules.embedding_ops import Embedding
-from torch.ao.nn.quantized.modules.embedding_ops import EmbeddingBag
-from torch.ao.nn.quantized.modules.embedding_ops import EmbeddingPackedParams
+from torch.ao.nn.quantized.modules.embedding_ops import (
+    Embedding,
+    EmbeddingBag,
+    EmbeddingPackedParams,
+)
+__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]

torch/nn/quantized/modules/functional_modules.py
@@ -8,8 +8,11 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional']
-from torch.ao.nn.quantized.modules.functional_modules import FloatFunctional
-from torch.ao.nn.quantized.modules.functional_modules import FXFloatFunctional
-from torch.ao.nn.quantized.modules.functional_modules import QFunctional
+from torch.ao.nn.quantized.modules.functional_modules import (
+    FloatFunctional,
+    FXFloatFunctional,
+    QFunctional,
+)
+__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]

torch/nn/quantized/modules/linear.py
@@ -8,7 +8,7 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-__all__ = ['LinearPackedParams', 'Linear']
-from torch.ao.nn.quantized.modules.linear import Linear
-from torch.ao.nn.quantized.modules.linear import LinearPackedParams
+from torch.ao.nn.quantized.modules.linear import Linear, LinearPackedParams
+__all__ = ["LinearPackedParams", "Linear"]

torch/nn/quantized/modules/normalization.py
@@ -8,10 +8,19 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
-from torch.ao.nn.quantized.modules.normalization import LayerNorm
-from torch.ao.nn.quantized.modules.normalization import GroupNorm
-from torch.ao.nn.quantized.modules.normalization import InstanceNorm1d
-from torch.ao.nn.quantized.modules.normalization import InstanceNorm2d
-from torch.ao.nn.quantized.modules.normalization import InstanceNorm3d
+from torch.ao.nn.quantized.modules.normalization import (
+    GroupNorm,
+    InstanceNorm1d,
+    InstanceNorm2d,
+    InstanceNorm3d,
+    LayerNorm,
+)
+__all__ = [
+    "LayerNorm",
+    "GroupNorm",
+    "InstanceNorm1d",
+    "InstanceNorm2d",
+    "InstanceNorm3d",
+]

torch/nn/quantized/modules/utils.py
@@ -8,8 +8,10 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
 while adding an import statement here.
 """
-from torch.ao.nn.quantized.modules.utils import _ntuple_from_first
-from torch.ao.nn.quantized.modules.utils import _pair_from_first
-from torch.ao.nn.quantized.modules.utils import _quantize_weight
-from torch.ao.nn.quantized.modules.utils import _hide_packed_params_repr
-from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
+from torch.ao.nn.quantized.modules.utils import (
+    _hide_packed_params_repr,
+    _ntuple_from_first,
+    _pair_from_first,
+    _quantize_weight,
+    WeightedQuantizedModule,
+)