[BE][Easy] enable UFMT for torch/nn/ (#128865)

Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128865
Approved by: https://github.com/ezyang
Author: Xuehai Pan
Date: 2024-07-22 14:58:12 +08:00
Committed by: PyTorch MergeBot
Parent: 8ea4c72eb2
Commit: b5c006acac
53 changed files with 517 additions and 459 deletions
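UFMT chains two formatters over every file: usort, which merges and sorts import statements, and Black, which normalizes layout and switches string literals to double quotes. That accounts for most of the churn in the diffs below; the remaining changes rewrite relative imports (e.g. `from .modules import *`) as absolute ones. A minimal before/after sketch of the enforced style, reusing names that appear in the real diffs:

Before:

    from torch.ao.nn.intrinsic import ConvBn1d
    from torch.ao.nn.intrinsic import ConvBn2d

    __all__ = ['ConvBn1d', 'ConvBn2d']

After:

    from torch.ao.nn.intrinsic import ConvBn1d, ConvBn2d

    __all__ = ["ConvBn1d", "ConvBn2d"]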

.lintrunner.toml

@@ -1476,59 +1476,6 @@ exclude_patterns = [
'torch/linalg/__init__.py',
'torch/monitor/__init__.py',
'torch/nested/__init__.py',
'torch/nn/intrinsic/__init__.py',
'torch/nn/intrinsic/modules/__init__.py',
'torch/nn/intrinsic/modules/fused.py',
'torch/nn/intrinsic/qat/__init__.py',
'torch/nn/intrinsic/qat/modules/__init__.py',
'torch/nn/intrinsic/qat/modules/conv_fused.py',
'torch/nn/intrinsic/qat/modules/linear_fused.py',
'torch/nn/intrinsic/qat/modules/linear_relu.py',
'torch/nn/intrinsic/quantized/__init__.py',
'torch/nn/intrinsic/quantized/dynamic/__init__.py',
'torch/nn/intrinsic/quantized/dynamic/modules/__init__.py',
'torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py',
'torch/nn/intrinsic/quantized/modules/__init__.py',
'torch/nn/intrinsic/quantized/modules/bn_relu.py',
'torch/nn/intrinsic/quantized/modules/conv_relu.py',
'torch/nn/intrinsic/quantized/modules/linear_relu.py',
'torch/nn/qat/__init__.py',
'torch/nn/qat/dynamic/__init__.py',
'torch/nn/qat/dynamic/modules/__init__.py',
'torch/nn/qat/dynamic/modules/linear.py',
'torch/nn/qat/modules/__init__.py',
'torch/nn/qat/modules/conv.py',
'torch/nn/qat/modules/embedding_ops.py',
'torch/nn/qat/modules/linear.py',
'torch/nn/quantizable/__init__.py',
'torch/nn/quantizable/modules/__init__.py',
'torch/nn/quantizable/modules/activation.py',
'torch/nn/quantizable/modules/rnn.py',
'torch/nn/quantized/__init__.py',
'torch/nn/quantized/_reference/__init__.py',
'torch/nn/quantized/_reference/modules/__init__.py',
'torch/nn/quantized/_reference/modules/conv.py',
'torch/nn/quantized/_reference/modules/linear.py',
'torch/nn/quantized/_reference/modules/rnn.py',
'torch/nn/quantized/_reference/modules/sparse.py',
'torch/nn/quantized/_reference/modules/utils.py',
'torch/nn/quantized/dynamic/__init__.py',
'torch/nn/quantized/dynamic/modules/__init__.py',
'torch/nn/quantized/dynamic/modules/conv.py',
'torch/nn/quantized/dynamic/modules/linear.py',
'torch/nn/quantized/dynamic/modules/rnn.py',
'torch/nn/quantized/functional.py',
'torch/nn/quantized/modules/__init__.py',
'torch/nn/quantized/modules/activation.py',
'torch/nn/quantized/modules/batchnorm.py',
'torch/nn/quantized/modules/conv.py',
'torch/nn/quantized/modules/dropout.py',
'torch/nn/quantized/modules/embedding_ops.py',
'torch/nn/quantized/modules/functional_modules.py',
'torch/nn/quantized/modules/linear.py',
'torch/nn/quantized/modules/normalization.py',
'torch/nn/quantized/modules/rnn.py',
'torch/nn/quantized/modules/utils.py',
'torch/signal/__init__.py',
'torch/signal/windows/__init__.py',
'torch/signal/windows/windows.py',
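Deleting these entries from the UFMT exclude_patterns list is what brings the torch/nn compatibility shims under the formatter; the per-file diffs that follow are the result of running it once over the newly covered paths, typically something like `lintrunner -a --take UFMT <paths>` locally (the exact invocation is an assumption here, not taken from the PR).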

torch/nn/intrinsic/__init__.py

@@ -1,35 +1,36 @@
from torch.ao.nn.intrinsic import ConvBn1d
from torch.ao.nn.intrinsic import ConvBn2d
from torch.ao.nn.intrinsic import ConvBn3d
from torch.ao.nn.intrinsic import ConvBnReLU1d
from torch.ao.nn.intrinsic import ConvBnReLU2d
from torch.ao.nn.intrinsic import ConvBnReLU3d
from torch.ao.nn.intrinsic import ConvReLU1d
from torch.ao.nn.intrinsic import ConvReLU2d
from torch.ao.nn.intrinsic import ConvReLU3d
from torch.ao.nn.intrinsic import LinearReLU
from torch.ao.nn.intrinsic import BNReLU2d
from torch.ao.nn.intrinsic import BNReLU3d
from torch.ao.nn.intrinsic import LinearBn1d
from torch.ao.nn.intrinsic import (
BNReLU2d,
BNReLU3d,
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
LinearBn1d,
LinearReLU,
)
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
# Include the subpackages in case user imports from it directly
from . import modules # noqa: F401
from . import qat # noqa: F401
from . import quantized # noqa: F401
from torch.nn.intrinsic import modules, qat, quantized # noqa: F401
__all__ = [
'ConvBn1d',
'ConvBn2d',
'ConvBn3d',
'ConvBnReLU1d',
'ConvBnReLU2d',
'ConvBnReLU3d',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
'LinearReLU',
'BNReLU2d',
'BNReLU3d',
'LinearBn1d',
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
]

torch/nn/intrinsic/modules/__init__.py

@@ -1,31 +1,33 @@
from .fused import _FusedModule # noqa: F401
from .fused import BNReLU2d
from .fused import BNReLU3d
from .fused import ConvBn1d
from .fused import ConvBn2d
from .fused import ConvBn3d
from .fused import ConvBnReLU1d
from .fused import ConvBnReLU2d
from .fused import ConvBnReLU3d
from .fused import ConvReLU1d
from .fused import ConvReLU2d
from .fused import ConvReLU3d
from .fused import LinearBn1d
from .fused import LinearReLU
from torch.nn.intrinsic.modules.fused import (
_FusedModule,
BNReLU2d,
BNReLU3d,
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
LinearBn1d,
LinearReLU,
)
__all__ = [
'BNReLU2d',
'BNReLU3d',
'ConvBn1d',
'ConvBn2d',
'ConvBn3d',
'ConvBnReLU1d',
'ConvBnReLU2d',
'ConvBnReLU3d',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
'LinearBn1d',
'LinearReLU',
"BNReLU2d",
"BNReLU3d",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearBn1d",
"LinearReLU",
]

torch/nn/intrinsic/modules/fused.py

@@ -1,30 +1,33 @@
from torch.ao.nn.intrinsic import BNReLU2d
from torch.ao.nn.intrinsic import BNReLU3d
from torch.ao.nn.intrinsic import ConvBn1d
from torch.ao.nn.intrinsic import ConvBn2d
from torch.ao.nn.intrinsic import ConvBn3d
from torch.ao.nn.intrinsic import ConvBnReLU1d
from torch.ao.nn.intrinsic import ConvBnReLU2d
from torch.ao.nn.intrinsic import ConvBnReLU3d
from torch.ao.nn.intrinsic import ConvReLU1d
from torch.ao.nn.intrinsic import ConvReLU2d
from torch.ao.nn.intrinsic import ConvReLU3d
from torch.ao.nn.intrinsic import LinearBn1d
from torch.ao.nn.intrinsic import LinearReLU
from torch.ao.nn.intrinsic import (
BNReLU2d,
BNReLU3d,
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
LinearBn1d,
LinearReLU,
)
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
__all__ = [
'BNReLU2d',
'BNReLU3d',
'ConvBn1d',
'ConvBn2d',
'ConvBn3d',
'ConvBnReLU1d',
'ConvBnReLU2d',
'ConvBnReLU3d',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
'LinearBn1d',
'LinearReLU',
"BNReLU2d",
"BNReLU3d",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearBn1d",
"LinearReLU",
]

torch/nn/intrinsic/qat/__init__.py

@@ -1 +1 @@
from .modules import * # noqa: F403
from torch.nn.intrinsic.qat.modules import * # noqa: F403

torch/nn/intrinsic/qat/modules/__init__.py

@@ -1,6 +1,4 @@
from .linear_relu import LinearReLU
from .linear_fused import LinearBn1d
from .conv_fused import (
from torch.nn.intrinsic.qat.modules.conv_fused import (
ConvBn1d,
ConvBn2d,
ConvBn3d,
@@ -10,9 +8,12 @@ from .conv_fused import (
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
update_bn_stats,
freeze_bn_stats,
update_bn_stats,
)
from torch.nn.intrinsic.qat.modules.linear_fused import LinearBn1d
from torch.nn.intrinsic.qat.modules.linear_relu import LinearReLU
__all__ = [
"LinearReLU",

torch/nn/intrinsic/qat/modules/conv_fused.py

@@ -8,30 +8,33 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
while adding an import statement here.
"""
from torch.ao.nn.intrinsic.qat import (
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
freeze_bn_stats,
update_bn_stats,
)
__all__ = [
# Modules
'ConvBn1d',
'ConvBnReLU1d',
'ConvReLU1d',
'ConvBn2d',
'ConvBnReLU2d',
'ConvReLU2d',
'ConvBn3d',
'ConvBnReLU3d',
'ConvReLU3d',
"ConvBn1d",
"ConvBnReLU1d",
"ConvReLU1d",
"ConvBn2d",
"ConvBnReLU2d",
"ConvReLU2d",
"ConvBn3d",
"ConvBnReLU3d",
"ConvReLU3d",
# Utilities
'freeze_bn_stats',
'update_bn_stats',
"freeze_bn_stats",
"update_bn_stats",
]
from torch.ao.nn.intrinsic.qat import ConvBn1d
from torch.ao.nn.intrinsic.qat import ConvBnReLU1d
from torch.ao.nn.intrinsic.qat import ConvReLU1d
from torch.ao.nn.intrinsic.qat import ConvBn2d
from torch.ao.nn.intrinsic.qat import ConvBnReLU2d
from torch.ao.nn.intrinsic.qat import ConvReLU2d
from torch.ao.nn.intrinsic.qat import ConvBn3d
from torch.ao.nn.intrinsic.qat import ConvBnReLU3d
from torch.ao.nn.intrinsic.qat import ConvReLU3d
from torch.ao.nn.intrinsic.qat import freeze_bn_stats
from torch.ao.nn.intrinsic.qat import update_bn_stats

torch/nn/intrinsic/qat/modules/linear_fused.py

@@ -8,8 +8,9 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
while adding an import statement here.
"""
__all__ = [
'LinearBn1d',
]
from torch.ao.nn.intrinsic.qat import LinearBn1d
__all__ = [
"LinearBn1d",
]

torch/nn/intrinsic/qat/modules/linear_relu.py

@@ -8,8 +8,9 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
while adding an import statement here.
"""
__all__ = [
'LinearReLU',
]
from torch.ao.nn.intrinsic.qat import LinearReLU
__all__ = [
"LinearReLU",
]

torch/nn/intrinsic/quantized/__init__.py

@@ -1,13 +1,14 @@
from .modules import * # noqa: F403
# to ensure customers can use the module below
# without importing it directly
import torch.nn.intrinsic.quantized.dynamic
from torch.nn.intrinsic.quantized import dynamic, modules # noqa: F401
from torch.nn.intrinsic.quantized.modules import * # noqa: F403
__all__ = [
'BNReLU2d',
'BNReLU3d',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
'LinearReLU',
"BNReLU2d",
"BNReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
]

torch/nn/intrinsic/quantized/dynamic/__init__.py

@@ -1 +1 @@
from .modules import * # noqa: F403
from torch.nn.intrinsic.quantized.dynamic.modules import * # noqa: F403

torch/nn/intrinsic/quantized/dynamic/modules/__init__.py

@@ -1,5 +1,6 @@
from .linear_relu import LinearReLU
from torch.nn.intrinsic.quantized.dynamic.modules.linear_relu import LinearReLU
__all__ = [
'LinearReLU',
"LinearReLU",
]

torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py

@@ -1,5 +1,6 @@
from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU
__all__ = [
'LinearReLU',
"LinearReLU",
]

torch/nn/intrinsic/quantized/modules/__init__.py

@@ -1,12 +1,17 @@
from .linear_relu import LinearReLU
from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
from .bn_relu import BNReLU2d, BNReLU3d
from torch.nn.intrinsic.quantized.modules.bn_relu import BNReLU2d, BNReLU3d
from torch.nn.intrinsic.quantized.modules.conv_relu import (
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
)
from torch.nn.intrinsic.quantized.modules.linear_relu import LinearReLU
__all__ = [
'LinearReLU',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
'BNReLU2d',
'BNReLU3d',
"LinearReLU",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"BNReLU2d",
"BNReLU3d",
]

torch/nn/intrinsic/quantized/modules/bn_relu.py

@@ -1,7 +1,7 @@
from torch.ao.nn.intrinsic.quantized import BNReLU2d
from torch.ao.nn.intrinsic.quantized import BNReLU3d
from torch.ao.nn.intrinsic.quantized import BNReLU2d, BNReLU3d
__all__ = [
'BNReLU2d',
'BNReLU3d',
"BNReLU2d",
"BNReLU3d",
]

torch/nn/intrinsic/quantized/modules/conv_relu.py

@@ -1,9 +1,8 @@
from torch.ao.nn.intrinsic.quantized import ConvReLU1d
from torch.ao.nn.intrinsic.quantized import ConvReLU2d
from torch.ao.nn.intrinsic.quantized import ConvReLU3d
from torch.ao.nn.intrinsic.quantized import ConvReLU1d, ConvReLU2d, ConvReLU3d
__all__ = [
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
]

torch/nn/intrinsic/quantized/modules/linear_relu.py

@@ -1,5 +1,6 @@
from torch.ao.nn.intrinsic.quantized import LinearReLU
__all__ = [
'LinearReLU',
"LinearReLU",
]

torch/nn/parallel/__init__.py

@@ -1,11 +1,11 @@
# mypy: allow-untyped-defs
from typing_extensions import deprecated
from .data_parallel import data_parallel, DataParallel
from .distributed import DistributedDataParallel
from .parallel_apply import parallel_apply
from .replicate import replicate
from .scatter_gather import gather, scatter
from torch.nn.parallel.data_parallel import data_parallel, DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter
__all__ = [

torch/nn/parallel/_functions.py

@@ -4,8 +4,7 @@ from typing import List, Optional
import torch
from torch._utils import _get_device_index
from torch.autograd import Function
from . import comm
from torch.nn.parallel import comm
class Broadcast(Function):

torch/nn/parallel/data_parallel.py

@@ -11,11 +11,10 @@ from torch._utils import (
_get_device_index,
_get_devices_properties,
)
from ..modules import Module
from .parallel_apply import parallel_apply
from .replicate import replicate
from .scatter_gather import gather, scatter_kwargs
from torch.nn.modules import Module
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
__all__ = ["DataParallel", "data_parallel"]

torch/nn/parallel/distributed.py

@@ -19,11 +19,10 @@ import torch.distributed as dist
from torch._utils import _get_device_index
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.nn.modules import Module
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
from torch.utils._pytree import tree_flatten, tree_unflatten
from ..modules import Module
from .scatter_gather import gather, scatter_kwargs
RPC_AVAILABLE = False
if dist.is_available():
@@ -47,6 +46,7 @@ if dist.rpc.is_available():
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
__all__ = ["DistributedDataParallel"]
logger = logging.getLogger(__name__)

torch/nn/parallel/parallel_apply.py

@@ -4,8 +4,7 @@ from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch._utils import ExceptionWrapper
from torch.cuda._utils import _get_device_index
from ..modules import Module
from torch.nn.modules import Module
__all__ = ["get_a_var", "parallel_apply"]

torch/nn/parallel/replicate.py

@@ -14,9 +14,8 @@ from typing import (
import torch
from torch._utils import _get_device_index
from ..modules import Module
from . import comm
from torch.nn.modules import Module
from torch.nn.parallel import comm
if TYPE_CHECKING:
@@ -94,7 +93,7 @@ def _broadcast_coalesced_reshape(
devices: Sequence[Union[int, torch.device]],
detach: bool = False,
) -> List[List[torch.Tensor]]:
from ._functions import Broadcast
from torch.nn.parallel._functions import Broadcast
if detach:
return comm.broadcast_coalesced(tensors, devices)

torch/nn/parallel/scatter_gather.py

@@ -3,8 +3,7 @@ from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar
from typing_extensions import deprecated
import torch
from ._functions import Gather, Scatter
from torch.nn.parallel._functions import Gather, Scatter
__all__ = ["scatter", "scatter_kwargs", "gather"]

torch/nn/qat/__init__.py

@@ -4,9 +4,9 @@ r"""QAT Dynamic Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.dynamic` instead.
"""
from . import dynamic # noqa: F403
from . import modules # noqa: F403
from .modules import * # noqa: F403
from torch.nn.qat import dynamic, modules # noqa: F403
from torch.nn.qat.modules import * # noqa: F403
__all__ = [
"Linear",

torch/nn/qat/dynamic/__init__.py

@@ -4,4 +4,4 @@ r"""QAT Dynamic Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.dynamic` instead.
"""
from .modules import * # noqa: F403
from torch.nn.qat.dynamic.modules import * # noqa: F403

torch/nn/qat/dynamic/modules/__init__.py

@@ -1,3 +1,4 @@
from .linear import Linear
from torch.nn.qat.dynamic.modules.linear import Linear
__all__ = ["Linear"]

torch/nn/qat/modules/__init__.py

@@ -4,15 +4,11 @@ r"""QAT Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.modules` instead.
"""
from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
from torch.ao.nn.qat.modules.linear import Linear
from torch.ao.nn.qat.modules.conv import Conv1d
from torch.ao.nn.qat.modules.conv import Conv2d
from torch.ao.nn.qat.modules.conv import Conv3d
from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag, Embedding
from torch.nn.qat.modules import conv, embedding_ops, linear
from . import conv
from . import embedding_ops
from . import linear
__all__ = [
"Linear",

torch/nn/qat/modules/conv.py

@@ -7,6 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/modules`,
while adding an import statement here.
"""
from torch.ao.nn.qat.modules.conv import Conv1d
from torch.ao.nn.qat.modules.conv import Conv2d
from torch.ao.nn.qat.modules.conv import Conv3d
from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d

torch/nn/qat/modules/embedding_ops.py

@@ -8,7 +8,7 @@ appropriate file under the `torch/ao/nn/qat/modules`,
while adding an import statement here.
"""
__all__ = ['Embedding', 'EmbeddingBag']
from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
from torch.ao.nn.qat.modules.embedding_ops import Embedding
from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag
__all__ = ["Embedding", "EmbeddingBag"]

torch/nn/quantizable/__init__.py

@@ -1 +1 @@
from .modules import * # noqa: F403
from torch.nn.quantizable.modules import * # noqa: F403

torch/nn/quantizable/modules/__init__.py

@@ -1,9 +1,9 @@
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
from torch.ao.nn.quantizable.modules.rnn import LSTM
from torch.ao.nn.quantizable.modules.rnn import LSTMCell
from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell
__all__ = [
'LSTM',
'LSTMCell',
'MultiheadAttention',
"LSTM",
"LSTMCell",
"MultiheadAttention",
]

torch/nn/quantizable/modules/rnn.py

@@ -7,5 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantizable/modules`,
while adding an import statement here.
"""
from torch.ao.nn.quantizable.modules.rnn import LSTM
from torch.ao.nn.quantizable.modules.rnn import LSTMCell
from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell

torch/nn/quantized/__init__.py

@@ -1,40 +1,39 @@
from . import dynamic # noqa: F403
from . import functional # noqa: F403
from . import modules # noqa: F403
from .modules import * # noqa: F403
from .modules import MaxPool2d
from torch.nn.quantized import dynamic, functional, modules # noqa: F403
from torch.nn.quantized.modules import * # noqa: F403
from torch.nn.quantized.modules import MaxPool2d
__all__ = [
'BatchNorm2d',
'BatchNorm3d',
'Conv1d',
'Conv2d',
'Conv3d',
'ConvTranspose1d',
'ConvTranspose2d',
'ConvTranspose3d',
'DeQuantize',
'Dropout',
'ELU',
'Embedding',
'EmbeddingBag',
'GroupNorm',
'Hardswish',
'InstanceNorm1d',
'InstanceNorm2d',
'InstanceNorm3d',
'LayerNorm',
'LeakyReLU',
'Linear',
'LSTM',
'MultiheadAttention',
'PReLU',
'Quantize',
'ReLU6',
'Sigmoid',
'Softmax',
"BatchNorm2d",
"BatchNorm3d",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"DeQuantize",
"Dropout",
"ELU",
"Embedding",
"EmbeddingBag",
"GroupNorm",
"Hardswish",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"LayerNorm",
"LeakyReLU",
"Linear",
"LSTM",
"MultiheadAttention",
"PReLU",
"Quantize",
"ReLU6",
"Sigmoid",
"Softmax",
# Wrapper modules
'FloatFunctional',
'FXFloatFunctional',
'QFunctional',
"FloatFunctional",
"FXFloatFunctional",
"QFunctional",
]

torch/nn/quantized/_reference/__init__.py

@@ -1 +1 @@
from .modules import * # noqa: F403
from torch.nn.quantized._reference.modules import * # noqa: F403

torch/nn/quantized/_reference/modules/__init__.py

@@ -9,23 +9,31 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.reference.modules.conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from torch.ao.nn.quantized.reference.modules.linear import Linear
from torch.ao.nn.quantized.reference.modules.conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
from torch.ao.nn.quantized.reference.modules.rnn import RNNCell, LSTMCell, GRUCell, LSTM
from torch.ao.nn.quantized.reference.modules.rnn import GRUCell, LSTM, LSTMCell, RNNCell
from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
__all__ = [
'Linear',
'Conv1d',
'Conv2d',
'Conv3d',
'ConvTranspose1d',
'ConvTranspose2d',
'ConvTranspose3d',
'RNNCell',
'LSTMCell',
'GRUCell',
'LSTM',
'Embedding',
'EmbeddingBag',
"Linear",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"RNNCell",
"LSTMCell",
"GRUCell",
"LSTM",
"Embedding",
"EmbeddingBag",
]

torch/nn/quantized/_reference/modules/conv.py

@@ -9,11 +9,13 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.reference.modules.conv import _ConvNd
from torch.ao.nn.quantized.reference.modules.conv import Conv1d
from torch.ao.nn.quantized.reference.modules.conv import Conv2d
from torch.ao.nn.quantized.reference.modules.conv import Conv3d
from torch.ao.nn.quantized.reference.modules.conv import _ConvTransposeNd
from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose1d
from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose2d
from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose3d
from torch.ao.nn.quantized.reference.modules.conv import (
_ConvNd,
_ConvTransposeNd,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)

torch/nn/quantized/_reference/modules/rnn.py

@@ -9,9 +9,11 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.reference.modules.rnn import RNNCellBase
from torch.ao.nn.quantized.reference.modules.rnn import RNNCell
from torch.ao.nn.quantized.reference.modules.rnn import LSTMCell
from torch.ao.nn.quantized.reference.modules.rnn import GRUCell
from torch.ao.nn.quantized.reference.modules.rnn import RNNBase
from torch.ao.nn.quantized.reference.modules.rnn import LSTM
from torch.ao.nn.quantized.reference.modules.rnn import (
GRUCell,
LSTM,
LSTMCell,
RNNBase,
RNNCell,
RNNCellBase,
)

torch/nn/quantized/_reference/modules/sparse.py

@@ -9,5 +9,4 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.reference.modules.sparse import Embedding
from torch.ao.nn.quantized.reference.modules.sparse import EmbeddingBag
from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag

torch/nn/quantized/_reference/modules/utils.py

@@ -8,8 +8,11 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/reference`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.reference.modules.utils import _quantize_weight
from torch.ao.nn.quantized.reference.modules.utils import _quantize_and_dequantize_weight
from torch.ao.nn.quantized.reference.modules.utils import _save_weight_qparams
from torch.ao.nn.quantized.reference.modules.utils import _get_weight_qparam_keys
from torch.ao.nn.quantized.reference.modules.utils import ReferenceQuantizedModule
from torch.ao.nn.quantized.reference.modules.utils import (
_get_weight_qparam_keys,
_quantize_and_dequantize_weight,
_quantize_weight,
_save_weight_qparams,
ReferenceQuantizedModule,
)

torch/nn/quantized/dynamic/__init__.py

@@ -8,25 +8,36 @@ appropriate file under the `torch/ao/nn/quantized/dynamic`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.dynamic.modules import conv
from torch.ao.nn.quantized.dynamic.modules import linear
from torch.ao.nn.quantized.dynamic.modules import rnn
from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
from torch.ao.nn.quantized.dynamic.modules import conv, linear, rnn
from torch.ao.nn.quantized.dynamic.modules.conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from torch.ao.nn.quantized.dynamic.modules.linear import Linear
from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM, GRU, LSTMCell, RNNCell, GRUCell
from torch.ao.nn.quantized.dynamic.modules.rnn import (
GRU,
GRUCell,
LSTM,
LSTMCell,
RNNCell,
)
__all__ = [
'Linear',
'LSTM',
'GRU',
'LSTMCell',
'RNNCell',
'GRUCell',
'Conv1d',
'Conv2d',
'Conv3d',
'ConvTranspose1d',
'ConvTranspose2d',
'ConvTranspose3d',
"Linear",
"LSTM",
"GRU",
"LSTMCell",
"RNNCell",
"GRUCell",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]

torch/nn/quantized/dynamic/modules/conv.py

@@ -8,11 +8,21 @@ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
while adding an import statement here.
"""
__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
from torch.ao.nn.quantized.dynamic.modules.conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d
from torch.ao.nn.quantized.dynamic.modules.conv import Conv2d
from torch.ao.nn.quantized.dynamic.modules.conv import Conv3d
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose1d
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose2d
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose3d
__all__ = [
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]

torch/nn/quantized/dynamic/modules/rnn.py

@@ -8,15 +8,27 @@ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
while adding an import statement here.
"""
__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell',
'GRUCell']
from torch.ao.nn.quantized.dynamic.modules.rnn import (
GRU,
GRUCell,
LSTM,
LSTMCell,
pack_weight_bias,
PackedParameter,
RNNBase,
RNNCell,
RNNCellBase,
)
from torch.ao.nn.quantized.dynamic.modules.rnn import pack_weight_bias
from torch.ao.nn.quantized.dynamic.modules.rnn import PackedParameter
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNBase
from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM
from torch.ao.nn.quantized.dynamic.modules.rnn import GRU
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCellBase
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCell
from torch.ao.nn.quantized.dynamic.modules.rnn import LSTMCell
from torch.ao.nn.quantized.dynamic.modules.rnn import GRUCell
__all__ = [
"pack_weight_bias",
"PackedParameter",
"RNNBase",
"LSTM",
"GRU",
"RNNCellBase",
"RNNCell",
"LSTMCell",
"GRUCell",
]

torch/nn/quantized/modules/__init__.py

@@ -5,66 +5,93 @@ Note::
Please, use `torch.ao.nn.quantized` instead.
"""
from torch.ao.nn.quantized.modules.activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
from torch.ao.nn.quantized.modules.conv import Conv1d, Conv2d, Conv3d
from torch.ao.nn.quantized.modules.conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
from torch.ao.nn.quantized.modules.dropout import Dropout
from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
from torch.ao.nn.quantized.modules.functional_modules import FloatFunctional, FXFloatFunctional, QFunctional
from torch.ao.nn.quantized.modules.linear import Linear
from torch.ao.nn.quantized.modules.normalization import LayerNorm, GroupNorm, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
from torch.ao.nn.quantized.modules.rnn import LSTM
from torch.ao.nn.quantized.modules import MaxPool2d
from torch.ao.nn.quantized.modules import Quantize, DeQuantize
# The following imports are needed in case the user decides
# to import the files directly,
# s.a. `from torch.nn.quantized.modules.conv import ...`.
# No need to add them to the `__all__`.
from torch.ao.nn.quantized.modules import activation
from torch.ao.nn.quantized.modules import batchnorm
from torch.ao.nn.quantized.modules import conv
from torch.ao.nn.quantized.modules import dropout
from torch.ao.nn.quantized.modules import embedding_ops
from torch.ao.nn.quantized.modules import functional_modules
from torch.ao.nn.quantized.modules import linear
from torch.ao.nn.quantized.modules import normalization
from torch.ao.nn.quantized.modules import rnn
from torch.ao.nn.quantized.modules import utils
from torch.ao.nn.quantized.modules import (
activation,
batchnorm,
conv,
DeQuantize,
dropout,
embedding_ops,
functional_modules,
linear,
MaxPool2d,
normalization,
Quantize,
rnn,
utils,
)
from torch.ao.nn.quantized.modules.activation import (
ELU,
Hardswish,
LeakyReLU,
MultiheadAttention,
PReLU,
ReLU6,
Sigmoid,
Softmax,
)
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
from torch.ao.nn.quantized.modules.conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from torch.ao.nn.quantized.modules.dropout import Dropout
from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
from torch.ao.nn.quantized.modules.functional_modules import (
FloatFunctional,
FXFloatFunctional,
QFunctional,
)
from torch.ao.nn.quantized.modules.linear import Linear
from torch.ao.nn.quantized.modules.normalization import (
GroupNorm,
InstanceNorm1d,
InstanceNorm2d,
InstanceNorm3d,
LayerNorm,
)
from torch.ao.nn.quantized.modules.rnn import LSTM
__all__ = [
'BatchNorm2d',
'BatchNorm3d',
'Conv1d',
'Conv2d',
'Conv3d',
'ConvTranspose1d',
'ConvTranspose2d',
'ConvTranspose3d',
'DeQuantize',
'ELU',
'Embedding',
'EmbeddingBag',
'GroupNorm',
'Hardswish',
'InstanceNorm1d',
'InstanceNorm2d',
'InstanceNorm3d',
'LayerNorm',
'LeakyReLU',
'Linear',
'LSTM',
'MultiheadAttention',
'Quantize',
'ReLU6',
'Sigmoid',
'Softmax',
'Dropout',
'PReLU',
"BatchNorm2d",
"BatchNorm3d",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"DeQuantize",
"ELU",
"Embedding",
"EmbeddingBag",
"GroupNorm",
"Hardswish",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"LayerNorm",
"LeakyReLU",
"Linear",
"LSTM",
"MultiheadAttention",
"Quantize",
"ReLU6",
"Sigmoid",
"Softmax",
"Dropout",
"PReLU",
# Wrapper modules
'FloatFunctional',
'FXFloatFunctional',
'QFunctional',
"FloatFunctional",
"FXFloatFunctional",
"QFunctional",
]

torch/nn/quantized/modules/activation.py

@@ -8,11 +8,13 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.modules.activation import ELU
from torch.ao.nn.quantized.modules.activation import Hardswish
from torch.ao.nn.quantized.modules.activation import LeakyReLU
from torch.ao.nn.quantized.modules.activation import MultiheadAttention
from torch.ao.nn.quantized.modules.activation import PReLU
from torch.ao.nn.quantized.modules.activation import ReLU6
from torch.ao.nn.quantized.modules.activation import Sigmoid
from torch.ao.nn.quantized.modules.activation import Softmax
from torch.ao.nn.quantized.modules.activation import (
ELU,
Hardswish,
LeakyReLU,
MultiheadAttention,
PReLU,
ReLU6,
Sigmoid,
Softmax,
)

torch/nn/quantized/modules/batchnorm.py

@@ -8,5 +8,4 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm3d
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d

torch/nn/quantized/modules/conv.py

@@ -8,14 +8,22 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
from torch.ao.nn.quantized.modules.conv import (
_reverse_repeat_padding,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
from torch.ao.nn.quantized.modules.conv import Conv1d
from torch.ao.nn.quantized.modules.conv import Conv2d
from torch.ao.nn.quantized.modules.conv import Conv3d
from torch.ao.nn.quantized.modules.conv import ConvTranspose1d
from torch.ao.nn.quantized.modules.conv import ConvTranspose2d
from torch.ao.nn.quantized.modules.conv import ConvTranspose3d
__all__ = [
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]

torch/nn/quantized/modules/dropout.py

@@ -8,6 +8,7 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
__all__ = ['Dropout']
from torch.ao.nn.quantized.modules.dropout import Dropout
__all__ = ["Dropout"]

torch/nn/quantized/modules/embedding_ops.py

@@ -8,8 +8,11 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag']
from torch.ao.nn.quantized.modules.embedding_ops import (
Embedding,
EmbeddingBag,
EmbeddingPackedParams,
)
from torch.ao.nn.quantized.modules.embedding_ops import Embedding
from torch.ao.nn.quantized.modules.embedding_ops import EmbeddingBag
from torch.ao.nn.quantized.modules.embedding_ops import EmbeddingPackedParams
__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]

torch/nn/quantized/modules/functional_modules.py

@@ -8,8 +8,11 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional']
from torch.ao.nn.quantized.modules.functional_modules import (
FloatFunctional,
FXFloatFunctional,
QFunctional,
)
from torch.ao.nn.quantized.modules.functional_modules import FloatFunctional
from torch.ao.nn.quantized.modules.functional_modules import FXFloatFunctional
from torch.ao.nn.quantized.modules.functional_modules import QFunctional
__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]

torch/nn/quantized/modules/linear.py

@@ -8,7 +8,7 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
__all__ = ['LinearPackedParams', 'Linear']
from torch.ao.nn.quantized.modules.linear import Linear, LinearPackedParams
from torch.ao.nn.quantized.modules.linear import Linear
from torch.ao.nn.quantized.modules.linear import LinearPackedParams
__all__ = ["LinearPackedParams", "Linear"]

torch/nn/quantized/modules/normalization.py

@@ -8,10 +8,19 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
from torch.ao.nn.quantized.modules.normalization import (
GroupNorm,
InstanceNorm1d,
InstanceNorm2d,
InstanceNorm3d,
LayerNorm,
)
from torch.ao.nn.quantized.modules.normalization import LayerNorm
from torch.ao.nn.quantized.modules.normalization import GroupNorm
from torch.ao.nn.quantized.modules.normalization import InstanceNorm1d
from torch.ao.nn.quantized.modules.normalization import InstanceNorm2d
from torch.ao.nn.quantized.modules.normalization import InstanceNorm3d
__all__ = [
"LayerNorm",
"GroupNorm",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
]

torch/nn/quantized/modules/utils.py

@@ -8,8 +8,10 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.modules.utils import _ntuple_from_first
from torch.ao.nn.quantized.modules.utils import _pair_from_first
from torch.ao.nn.quantized.modules.utils import _quantize_weight
from torch.ao.nn.quantized.modules.utils import _hide_packed_params_repr
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.nn.quantized.modules.utils import (
_hide_packed_params_repr,
_ntuple_from_first,
_pair_from_first,
_quantize_weight,
WeightedQuantizedModule,
)