mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][Easy] enable UFMT for torch/nn/
(#128865)
Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128865 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
8ea4c72eb2
commit
b5c006acac
@ -1476,59 +1476,6 @@ exclude_patterns = [
|
|||||||
'torch/linalg/__init__.py',
|
'torch/linalg/__init__.py',
|
||||||
'torch/monitor/__init__.py',
|
'torch/monitor/__init__.py',
|
||||||
'torch/nested/__init__.py',
|
'torch/nested/__init__.py',
|
||||||
'torch/nn/intrinsic/__init__.py',
|
|
||||||
'torch/nn/intrinsic/modules/__init__.py',
|
|
||||||
'torch/nn/intrinsic/modules/fused.py',
|
|
||||||
'torch/nn/intrinsic/qat/__init__.py',
|
|
||||||
'torch/nn/intrinsic/qat/modules/__init__.py',
|
|
||||||
'torch/nn/intrinsic/qat/modules/conv_fused.py',
|
|
||||||
'torch/nn/intrinsic/qat/modules/linear_fused.py',
|
|
||||||
'torch/nn/intrinsic/qat/modules/linear_relu.py',
|
|
||||||
'torch/nn/intrinsic/quantized/__init__.py',
|
|
||||||
'torch/nn/intrinsic/quantized/dynamic/__init__.py',
|
|
||||||
'torch/nn/intrinsic/quantized/dynamic/modules/__init__.py',
|
|
||||||
'torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py',
|
|
||||||
'torch/nn/intrinsic/quantized/modules/__init__.py',
|
|
||||||
'torch/nn/intrinsic/quantized/modules/bn_relu.py',
|
|
||||||
'torch/nn/intrinsic/quantized/modules/conv_relu.py',
|
|
||||||
'torch/nn/intrinsic/quantized/modules/linear_relu.py',
|
|
||||||
'torch/nn/qat/__init__.py',
|
|
||||||
'torch/nn/qat/dynamic/__init__.py',
|
|
||||||
'torch/nn/qat/dynamic/modules/__init__.py',
|
|
||||||
'torch/nn/qat/dynamic/modules/linear.py',
|
|
||||||
'torch/nn/qat/modules/__init__.py',
|
|
||||||
'torch/nn/qat/modules/conv.py',
|
|
||||||
'torch/nn/qat/modules/embedding_ops.py',
|
|
||||||
'torch/nn/qat/modules/linear.py',
|
|
||||||
'torch/nn/quantizable/__init__.py',
|
|
||||||
'torch/nn/quantizable/modules/__init__.py',
|
|
||||||
'torch/nn/quantizable/modules/activation.py',
|
|
||||||
'torch/nn/quantizable/modules/rnn.py',
|
|
||||||
'torch/nn/quantized/__init__.py',
|
|
||||||
'torch/nn/quantized/_reference/__init__.py',
|
|
||||||
'torch/nn/quantized/_reference/modules/__init__.py',
|
|
||||||
'torch/nn/quantized/_reference/modules/conv.py',
|
|
||||||
'torch/nn/quantized/_reference/modules/linear.py',
|
|
||||||
'torch/nn/quantized/_reference/modules/rnn.py',
|
|
||||||
'torch/nn/quantized/_reference/modules/sparse.py',
|
|
||||||
'torch/nn/quantized/_reference/modules/utils.py',
|
|
||||||
'torch/nn/quantized/dynamic/__init__.py',
|
|
||||||
'torch/nn/quantized/dynamic/modules/__init__.py',
|
|
||||||
'torch/nn/quantized/dynamic/modules/conv.py',
|
|
||||||
'torch/nn/quantized/dynamic/modules/linear.py',
|
|
||||||
'torch/nn/quantized/dynamic/modules/rnn.py',
|
|
||||||
'torch/nn/quantized/functional.py',
|
|
||||||
'torch/nn/quantized/modules/__init__.py',
|
|
||||||
'torch/nn/quantized/modules/activation.py',
|
|
||||||
'torch/nn/quantized/modules/batchnorm.py',
|
|
||||||
'torch/nn/quantized/modules/conv.py',
|
|
||||||
'torch/nn/quantized/modules/dropout.py',
|
|
||||||
'torch/nn/quantized/modules/embedding_ops.py',
|
|
||||||
'torch/nn/quantized/modules/functional_modules.py',
|
|
||||||
'torch/nn/quantized/modules/linear.py',
|
|
||||||
'torch/nn/quantized/modules/normalization.py',
|
|
||||||
'torch/nn/quantized/modules/rnn.py',
|
|
||||||
'torch/nn/quantized/modules/utils.py',
|
|
||||||
'torch/signal/__init__.py',
|
'torch/signal/__init__.py',
|
||||||
'torch/signal/windows/__init__.py',
|
'torch/signal/windows/__init__.py',
|
||||||
'torch/signal/windows/windows.py',
|
'torch/signal/windows/windows.py',
|
||||||
|
@ -1,35 +1,36 @@
|
|||||||
from torch.ao.nn.intrinsic import ConvBn1d
|
from torch.ao.nn.intrinsic import (
|
||||||
from torch.ao.nn.intrinsic import ConvBn2d
|
BNReLU2d,
|
||||||
from torch.ao.nn.intrinsic import ConvBn3d
|
BNReLU3d,
|
||||||
from torch.ao.nn.intrinsic import ConvBnReLU1d
|
ConvBn1d,
|
||||||
from torch.ao.nn.intrinsic import ConvBnReLU2d
|
ConvBn2d,
|
||||||
from torch.ao.nn.intrinsic import ConvBnReLU3d
|
ConvBn3d,
|
||||||
from torch.ao.nn.intrinsic import ConvReLU1d
|
ConvBnReLU1d,
|
||||||
from torch.ao.nn.intrinsic import ConvReLU2d
|
ConvBnReLU2d,
|
||||||
from torch.ao.nn.intrinsic import ConvReLU3d
|
ConvBnReLU3d,
|
||||||
from torch.ao.nn.intrinsic import LinearReLU
|
ConvReLU1d,
|
||||||
from torch.ao.nn.intrinsic import BNReLU2d
|
ConvReLU2d,
|
||||||
from torch.ao.nn.intrinsic import BNReLU3d
|
ConvReLU3d,
|
||||||
from torch.ao.nn.intrinsic import LinearBn1d
|
LinearBn1d,
|
||||||
|
LinearReLU,
|
||||||
|
)
|
||||||
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
|
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
|
||||||
|
|
||||||
# Include the subpackages in case user imports from it directly
|
# Include the subpackages in case user imports from it directly
|
||||||
from . import modules # noqa: F401
|
from torch.nn.intrinsic import modules, qat, quantized # noqa: F401
|
||||||
from . import qat # noqa: F401
|
|
||||||
from . import quantized # noqa: F401
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ConvBn1d',
|
"ConvBn1d",
|
||||||
'ConvBn2d',
|
"ConvBn2d",
|
||||||
'ConvBn3d',
|
"ConvBn3d",
|
||||||
'ConvBnReLU1d',
|
"ConvBnReLU1d",
|
||||||
'ConvBnReLU2d',
|
"ConvBnReLU2d",
|
||||||
'ConvBnReLU3d',
|
"ConvBnReLU3d",
|
||||||
'ConvReLU1d',
|
"ConvReLU1d",
|
||||||
'ConvReLU2d',
|
"ConvReLU2d",
|
||||||
'ConvReLU3d',
|
"ConvReLU3d",
|
||||||
'LinearReLU',
|
"LinearReLU",
|
||||||
'BNReLU2d',
|
"BNReLU2d",
|
||||||
'BNReLU3d',
|
"BNReLU3d",
|
||||||
'LinearBn1d',
|
"LinearBn1d",
|
||||||
]
|
]
|
||||||
|
@ -1,31 +1,33 @@
|
|||||||
from .fused import _FusedModule # noqa: F401
|
from torch.nn.intrinsic.modules.fused import (
|
||||||
from .fused import BNReLU2d
|
_FusedModule,
|
||||||
from .fused import BNReLU3d
|
BNReLU2d,
|
||||||
from .fused import ConvBn1d
|
BNReLU3d,
|
||||||
from .fused import ConvBn2d
|
ConvBn1d,
|
||||||
from .fused import ConvBn3d
|
ConvBn2d,
|
||||||
from .fused import ConvBnReLU1d
|
ConvBn3d,
|
||||||
from .fused import ConvBnReLU2d
|
ConvBnReLU1d,
|
||||||
from .fused import ConvBnReLU3d
|
ConvBnReLU2d,
|
||||||
from .fused import ConvReLU1d
|
ConvBnReLU3d,
|
||||||
from .fused import ConvReLU2d
|
ConvReLU1d,
|
||||||
from .fused import ConvReLU3d
|
ConvReLU2d,
|
||||||
from .fused import LinearBn1d
|
ConvReLU3d,
|
||||||
from .fused import LinearReLU
|
LinearBn1d,
|
||||||
|
LinearReLU,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BNReLU2d',
|
"BNReLU2d",
|
||||||
'BNReLU3d',
|
"BNReLU3d",
|
||||||
'ConvBn1d',
|
"ConvBn1d",
|
||||||
'ConvBn2d',
|
"ConvBn2d",
|
||||||
'ConvBn3d',
|
"ConvBn3d",
|
||||||
'ConvBnReLU1d',
|
"ConvBnReLU1d",
|
||||||
'ConvBnReLU2d',
|
"ConvBnReLU2d",
|
||||||
'ConvBnReLU3d',
|
"ConvBnReLU3d",
|
||||||
'ConvReLU1d',
|
"ConvReLU1d",
|
||||||
'ConvReLU2d',
|
"ConvReLU2d",
|
||||||
'ConvReLU3d',
|
"ConvReLU3d",
|
||||||
'LinearBn1d',
|
"LinearBn1d",
|
||||||
'LinearReLU',
|
"LinearReLU",
|
||||||
]
|
]
|
||||||
|
@ -1,30 +1,33 @@
|
|||||||
from torch.ao.nn.intrinsic import BNReLU2d
|
from torch.ao.nn.intrinsic import (
|
||||||
from torch.ao.nn.intrinsic import BNReLU3d
|
BNReLU2d,
|
||||||
from torch.ao.nn.intrinsic import ConvBn1d
|
BNReLU3d,
|
||||||
from torch.ao.nn.intrinsic import ConvBn2d
|
ConvBn1d,
|
||||||
from torch.ao.nn.intrinsic import ConvBn3d
|
ConvBn2d,
|
||||||
from torch.ao.nn.intrinsic import ConvBnReLU1d
|
ConvBn3d,
|
||||||
from torch.ao.nn.intrinsic import ConvBnReLU2d
|
ConvBnReLU1d,
|
||||||
from torch.ao.nn.intrinsic import ConvBnReLU3d
|
ConvBnReLU2d,
|
||||||
from torch.ao.nn.intrinsic import ConvReLU1d
|
ConvBnReLU3d,
|
||||||
from torch.ao.nn.intrinsic import ConvReLU2d
|
ConvReLU1d,
|
||||||
from torch.ao.nn.intrinsic import ConvReLU3d
|
ConvReLU2d,
|
||||||
from torch.ao.nn.intrinsic import LinearBn1d
|
ConvReLU3d,
|
||||||
from torch.ao.nn.intrinsic import LinearReLU
|
LinearBn1d,
|
||||||
|
LinearReLU,
|
||||||
|
)
|
||||||
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
|
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BNReLU2d',
|
"BNReLU2d",
|
||||||
'BNReLU3d',
|
"BNReLU3d",
|
||||||
'ConvBn1d',
|
"ConvBn1d",
|
||||||
'ConvBn2d',
|
"ConvBn2d",
|
||||||
'ConvBn3d',
|
"ConvBn3d",
|
||||||
'ConvBnReLU1d',
|
"ConvBnReLU1d",
|
||||||
'ConvBnReLU2d',
|
"ConvBnReLU2d",
|
||||||
'ConvBnReLU3d',
|
"ConvBnReLU3d",
|
||||||
'ConvReLU1d',
|
"ConvReLU1d",
|
||||||
'ConvReLU2d',
|
"ConvReLU2d",
|
||||||
'ConvReLU3d',
|
"ConvReLU3d",
|
||||||
'LinearBn1d',
|
"LinearBn1d",
|
||||||
'LinearReLU',
|
"LinearReLU",
|
||||||
]
|
]
|
||||||
|
@ -1 +1 @@
|
|||||||
from .modules import * # noqa: F403
|
from torch.nn.intrinsic.qat.modules import * # noqa: F403
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
from .linear_relu import LinearReLU
|
from torch.nn.intrinsic.qat.modules.conv_fused import (
|
||||||
from .linear_fused import LinearBn1d
|
|
||||||
from .conv_fused import (
|
|
||||||
ConvBn1d,
|
ConvBn1d,
|
||||||
ConvBn2d,
|
ConvBn2d,
|
||||||
ConvBn3d,
|
ConvBn3d,
|
||||||
@ -10,9 +8,12 @@ from .conv_fused import (
|
|||||||
ConvReLU1d,
|
ConvReLU1d,
|
||||||
ConvReLU2d,
|
ConvReLU2d,
|
||||||
ConvReLU3d,
|
ConvReLU3d,
|
||||||
update_bn_stats,
|
|
||||||
freeze_bn_stats,
|
freeze_bn_stats,
|
||||||
|
update_bn_stats,
|
||||||
)
|
)
|
||||||
|
from torch.nn.intrinsic.qat.modules.linear_fused import LinearBn1d
|
||||||
|
from torch.nn.intrinsic.qat.modules.linear_relu import LinearReLU
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LinearReLU",
|
"LinearReLU",
|
||||||
|
@ -8,30 +8,33 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from torch.ao.nn.intrinsic.qat import (
|
||||||
|
ConvBn1d,
|
||||||
|
ConvBn2d,
|
||||||
|
ConvBn3d,
|
||||||
|
ConvBnReLU1d,
|
||||||
|
ConvBnReLU2d,
|
||||||
|
ConvBnReLU3d,
|
||||||
|
ConvReLU1d,
|
||||||
|
ConvReLU2d,
|
||||||
|
ConvReLU3d,
|
||||||
|
freeze_bn_stats,
|
||||||
|
update_bn_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Modules
|
# Modules
|
||||||
'ConvBn1d',
|
"ConvBn1d",
|
||||||
'ConvBnReLU1d',
|
"ConvBnReLU1d",
|
||||||
'ConvReLU1d',
|
"ConvReLU1d",
|
||||||
'ConvBn2d',
|
"ConvBn2d",
|
||||||
'ConvBnReLU2d',
|
"ConvBnReLU2d",
|
||||||
'ConvReLU2d',
|
"ConvReLU2d",
|
||||||
'ConvBn3d',
|
"ConvBn3d",
|
||||||
'ConvBnReLU3d',
|
"ConvBnReLU3d",
|
||||||
'ConvReLU3d',
|
"ConvReLU3d",
|
||||||
# Utilities
|
# Utilities
|
||||||
'freeze_bn_stats',
|
"freeze_bn_stats",
|
||||||
'update_bn_stats',
|
"update_bn_stats",
|
||||||
]
|
]
|
||||||
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvBn1d
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvBnReLU1d
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvReLU1d
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvBn2d
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvBnReLU2d
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvReLU2d
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvBn3d
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvBnReLU3d
|
|
||||||
from torch.ao.nn.intrinsic.qat import ConvReLU3d
|
|
||||||
from torch.ao.nn.intrinsic.qat import freeze_bn_stats
|
|
||||||
from torch.ao.nn.intrinsic.qat import update_bn_stats
|
|
||||||
|
@ -8,8 +8,9 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'LinearBn1d',
|
|
||||||
]
|
|
||||||
|
|
||||||
from torch.ao.nn.intrinsic.qat import LinearBn1d
|
from torch.ao.nn.intrinsic.qat import LinearBn1d
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LinearBn1d",
|
||||||
|
]
|
||||||
|
@ -8,8 +8,9 @@ appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'LinearReLU',
|
|
||||||
]
|
|
||||||
|
|
||||||
from torch.ao.nn.intrinsic.qat import LinearReLU
|
from torch.ao.nn.intrinsic.qat import LinearReLU
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LinearReLU",
|
||||||
|
]
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
from .modules import * # noqa: F403
|
|
||||||
# to ensure customers can use the module below
|
# to ensure customers can use the module below
|
||||||
# without importing it directly
|
# without importing it directly
|
||||||
import torch.nn.intrinsic.quantized.dynamic
|
from torch.nn.intrinsic.quantized import dynamic, modules # noqa: F401
|
||||||
|
from torch.nn.intrinsic.quantized.modules import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BNReLU2d',
|
"BNReLU2d",
|
||||||
'BNReLU3d',
|
"BNReLU3d",
|
||||||
'ConvReLU1d',
|
"ConvReLU1d",
|
||||||
'ConvReLU2d',
|
"ConvReLU2d",
|
||||||
'ConvReLU3d',
|
"ConvReLU3d",
|
||||||
'LinearReLU',
|
"LinearReLU",
|
||||||
]
|
]
|
||||||
|
@ -1 +1 @@
|
|||||||
from .modules import * # noqa: F403
|
from torch.nn.intrinsic.quantized.dynamic.modules import * # noqa: F403
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from .linear_relu import LinearReLU
|
from torch.nn.intrinsic.quantized.dynamic.modules.linear_relu import LinearReLU
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LinearReLU',
|
"LinearReLU",
|
||||||
]
|
]
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU
|
from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LinearReLU',
|
"LinearReLU",
|
||||||
]
|
]
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
from .linear_relu import LinearReLU
|
from torch.nn.intrinsic.quantized.modules.bn_relu import BNReLU2d, BNReLU3d
|
||||||
from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
|
from torch.nn.intrinsic.quantized.modules.conv_relu import (
|
||||||
from .bn_relu import BNReLU2d, BNReLU3d
|
ConvReLU1d,
|
||||||
|
ConvReLU2d,
|
||||||
|
ConvReLU3d,
|
||||||
|
)
|
||||||
|
from torch.nn.intrinsic.quantized.modules.linear_relu import LinearReLU
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LinearReLU',
|
"LinearReLU",
|
||||||
'ConvReLU1d',
|
"ConvReLU1d",
|
||||||
'ConvReLU2d',
|
"ConvReLU2d",
|
||||||
'ConvReLU3d',
|
"ConvReLU3d",
|
||||||
'BNReLU2d',
|
"BNReLU2d",
|
||||||
'BNReLU3d',
|
"BNReLU3d",
|
||||||
]
|
]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from torch.ao.nn.intrinsic.quantized import BNReLU2d
|
from torch.ao.nn.intrinsic.quantized import BNReLU2d, BNReLU3d
|
||||||
from torch.ao.nn.intrinsic.quantized import BNReLU3d
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BNReLU2d',
|
"BNReLU2d",
|
||||||
'BNReLU3d',
|
"BNReLU3d",
|
||||||
]
|
]
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
from torch.ao.nn.intrinsic.quantized import ConvReLU1d
|
from torch.ao.nn.intrinsic.quantized import ConvReLU1d, ConvReLU2d, ConvReLU3d
|
||||||
from torch.ao.nn.intrinsic.quantized import ConvReLU2d
|
|
||||||
from torch.ao.nn.intrinsic.quantized import ConvReLU3d
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ConvReLU1d',
|
"ConvReLU1d",
|
||||||
'ConvReLU2d',
|
"ConvReLU2d",
|
||||||
'ConvReLU3d',
|
"ConvReLU3d",
|
||||||
]
|
]
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from torch.ao.nn.intrinsic.quantized import LinearReLU
|
from torch.ao.nn.intrinsic.quantized import LinearReLU
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LinearReLU',
|
"LinearReLU",
|
||||||
]
|
]
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from .data_parallel import data_parallel, DataParallel
|
from torch.nn.parallel.data_parallel import data_parallel, DataParallel
|
||||||
from .distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
from .parallel_apply import parallel_apply
|
from torch.nn.parallel.parallel_apply import parallel_apply
|
||||||
from .replicate import replicate
|
from torch.nn.parallel.replicate import replicate
|
||||||
from .scatter_gather import gather, scatter
|
from torch.nn.parallel.scatter_gather import gather, scatter
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -4,8 +4,7 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
from torch._utils import _get_device_index
|
from torch._utils import _get_device_index
|
||||||
from torch.autograd import Function
|
from torch.autograd import Function
|
||||||
|
from torch.nn.parallel import comm
|
||||||
from . import comm
|
|
||||||
|
|
||||||
|
|
||||||
class Broadcast(Function):
|
class Broadcast(Function):
|
||||||
|
@ -11,11 +11,10 @@ from torch._utils import (
|
|||||||
_get_device_index,
|
_get_device_index,
|
||||||
_get_devices_properties,
|
_get_devices_properties,
|
||||||
)
|
)
|
||||||
|
from torch.nn.modules import Module
|
||||||
from ..modules import Module
|
from torch.nn.parallel.parallel_apply import parallel_apply
|
||||||
from .parallel_apply import parallel_apply
|
from torch.nn.parallel.replicate import replicate
|
||||||
from .replicate import replicate
|
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
|
||||||
from .scatter_gather import gather, scatter_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["DataParallel", "data_parallel"]
|
__all__ = ["DataParallel", "data_parallel"]
|
||||||
|
@ -19,11 +19,10 @@ import torch.distributed as dist
|
|||||||
from torch._utils import _get_device_index
|
from torch._utils import _get_device_index
|
||||||
from torch.autograd import Function, Variable
|
from torch.autograd import Function, Variable
|
||||||
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
|
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
|
||||||
|
from torch.nn.modules import Module
|
||||||
|
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
|
||||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||||
|
|
||||||
from ..modules import Module
|
|
||||||
from .scatter_gather import gather, scatter_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
RPC_AVAILABLE = False
|
RPC_AVAILABLE = False
|
||||||
if dist.is_available():
|
if dist.is_available():
|
||||||
@ -47,6 +46,7 @@ if dist.rpc.is_available():
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["DistributedDataParallel"]
|
__all__ = ["DistributedDataParallel"]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -4,8 +4,7 @@ from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch._utils import ExceptionWrapper
|
from torch._utils import ExceptionWrapper
|
||||||
from torch.cuda._utils import _get_device_index
|
from torch.cuda._utils import _get_device_index
|
||||||
|
from torch.nn.modules import Module
|
||||||
from ..modules import Module
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["get_a_var", "parallel_apply"]
|
__all__ = ["get_a_var", "parallel_apply"]
|
||||||
|
@ -14,9 +14,8 @@ from typing import (
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._utils import _get_device_index
|
from torch._utils import _get_device_index
|
||||||
|
from torch.nn.modules import Module
|
||||||
from ..modules import Module
|
from torch.nn.parallel import comm
|
||||||
from . import comm
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -94,7 +93,7 @@ def _broadcast_coalesced_reshape(
|
|||||||
devices: Sequence[Union[int, torch.device]],
|
devices: Sequence[Union[int, torch.device]],
|
||||||
detach: bool = False,
|
detach: bool = False,
|
||||||
) -> List[List[torch.Tensor]]:
|
) -> List[List[torch.Tensor]]:
|
||||||
from ._functions import Broadcast
|
from torch.nn.parallel._functions import Broadcast
|
||||||
|
|
||||||
if detach:
|
if detach:
|
||||||
return comm.broadcast_coalesced(tensors, devices)
|
return comm.broadcast_coalesced(tensors, devices)
|
||||||
|
@ -3,8 +3,7 @@ from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar
|
|||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.parallel._functions import Gather, Scatter
|
||||||
from ._functions import Gather, Scatter
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["scatter", "scatter_kwargs", "gather"]
|
__all__ = ["scatter", "scatter_kwargs", "gather"]
|
||||||
|
@ -4,9 +4,9 @@ r"""QAT Dynamic Modules.
|
|||||||
This package is in the process of being deprecated.
|
This package is in the process of being deprecated.
|
||||||
Please, use `torch.ao.nn.qat.dynamic` instead.
|
Please, use `torch.ao.nn.qat.dynamic` instead.
|
||||||
"""
|
"""
|
||||||
from . import dynamic # noqa: F403
|
from torch.nn.qat import dynamic, modules # noqa: F403
|
||||||
from . import modules # noqa: F403
|
from torch.nn.qat.modules import * # noqa: F403
|
||||||
from .modules import * # noqa: F403
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Linear",
|
"Linear",
|
||||||
|
@ -4,4 +4,4 @@ r"""QAT Dynamic Modules.
|
|||||||
This package is in the process of being deprecated.
|
This package is in the process of being deprecated.
|
||||||
Please, use `torch.ao.nn.qat.dynamic` instead.
|
Please, use `torch.ao.nn.qat.dynamic` instead.
|
||||||
"""
|
"""
|
||||||
from .modules import * # noqa: F403
|
from torch.nn.qat.dynamic.modules import * # noqa: F403
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
from .linear import Linear
|
from torch.nn.qat.dynamic.modules.linear import Linear
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Linear"]
|
__all__ = ["Linear"]
|
||||||
|
@ -4,15 +4,11 @@ r"""QAT Modules.
|
|||||||
This package is in the process of being deprecated.
|
This package is in the process of being deprecated.
|
||||||
Please, use `torch.ao.nn.qat.modules` instead.
|
Please, use `torch.ao.nn.qat.modules` instead.
|
||||||
"""
|
"""
|
||||||
|
from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
|
||||||
|
from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
|
||||||
from torch.ao.nn.qat.modules.linear import Linear
|
from torch.ao.nn.qat.modules.linear import Linear
|
||||||
from torch.ao.nn.qat.modules.conv import Conv1d
|
from torch.nn.qat.modules import conv, embedding_ops, linear
|
||||||
from torch.ao.nn.qat.modules.conv import Conv2d
|
|
||||||
from torch.ao.nn.qat.modules.conv import Conv3d
|
|
||||||
from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag, Embedding
|
|
||||||
|
|
||||||
from . import conv
|
|
||||||
from . import embedding_ops
|
|
||||||
from . import linear
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Linear",
|
"Linear",
|
||||||
|
@ -7,6 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
|
|||||||
appropriate file under the `torch/ao/nn/qat/modules`,
|
appropriate file under the `torch/ao/nn/qat/modules`,
|
||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
from torch.ao.nn.qat.modules.conv import Conv1d
|
|
||||||
from torch.ao.nn.qat.modules.conv import Conv2d
|
from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
|
||||||
from torch.ao.nn.qat.modules.conv import Conv3d
|
|
||||||
|
@ -8,7 +8,7 @@ appropriate file under the `torch/ao/nn/qat/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['Embedding', 'EmbeddingBag']
|
from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
|
||||||
|
|
||||||
from torch.ao.nn.qat.modules.embedding_ops import Embedding
|
|
||||||
from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag
|
__all__ = ["Embedding", "EmbeddingBag"]
|
||||||
|
@ -1 +1 @@
|
|||||||
from .modules import * # noqa: F403
|
from torch.nn.quantizable.modules import * # noqa: F403
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
|
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
|
||||||
from torch.ao.nn.quantizable.modules.rnn import LSTM
|
from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell
|
||||||
from torch.ao.nn.quantizable.modules.rnn import LSTMCell
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LSTM',
|
"LSTM",
|
||||||
'LSTMCell',
|
"LSTMCell",
|
||||||
'MultiheadAttention',
|
"MultiheadAttention",
|
||||||
]
|
]
|
||||||
|
@ -7,5 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
|
|||||||
appropriate file under the `torch/ao/nn/quantizable/modules`,
|
appropriate file under the `torch/ao/nn/quantizable/modules`,
|
||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
from torch.ao.nn.quantizable.modules.rnn import LSTM
|
|
||||||
from torch.ao.nn.quantizable.modules.rnn import LSTMCell
|
from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell
|
||||||
|
@ -1,40 +1,39 @@
|
|||||||
from . import dynamic # noqa: F403
|
from torch.nn.quantized import dynamic, functional, modules # noqa: F403
|
||||||
from . import functional # noqa: F403
|
from torch.nn.quantized.modules import * # noqa: F403
|
||||||
from . import modules # noqa: F403
|
from torch.nn.quantized.modules import MaxPool2d
|
||||||
from .modules import * # noqa: F403
|
|
||||||
from .modules import MaxPool2d
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BatchNorm2d',
|
"BatchNorm2d",
|
||||||
'BatchNorm3d',
|
"BatchNorm3d",
|
||||||
'Conv1d',
|
"Conv1d",
|
||||||
'Conv2d',
|
"Conv2d",
|
||||||
'Conv3d',
|
"Conv3d",
|
||||||
'ConvTranspose1d',
|
"ConvTranspose1d",
|
||||||
'ConvTranspose2d',
|
"ConvTranspose2d",
|
||||||
'ConvTranspose3d',
|
"ConvTranspose3d",
|
||||||
'DeQuantize',
|
"DeQuantize",
|
||||||
'Dropout',
|
"Dropout",
|
||||||
'ELU',
|
"ELU",
|
||||||
'Embedding',
|
"Embedding",
|
||||||
'EmbeddingBag',
|
"EmbeddingBag",
|
||||||
'GroupNorm',
|
"GroupNorm",
|
||||||
'Hardswish',
|
"Hardswish",
|
||||||
'InstanceNorm1d',
|
"InstanceNorm1d",
|
||||||
'InstanceNorm2d',
|
"InstanceNorm2d",
|
||||||
'InstanceNorm3d',
|
"InstanceNorm3d",
|
||||||
'LayerNorm',
|
"LayerNorm",
|
||||||
'LeakyReLU',
|
"LeakyReLU",
|
||||||
'Linear',
|
"Linear",
|
||||||
'LSTM',
|
"LSTM",
|
||||||
'MultiheadAttention',
|
"MultiheadAttention",
|
||||||
'PReLU',
|
"PReLU",
|
||||||
'Quantize',
|
"Quantize",
|
||||||
'ReLU6',
|
"ReLU6",
|
||||||
'Sigmoid',
|
"Sigmoid",
|
||||||
'Softmax',
|
"Softmax",
|
||||||
# Wrapper modules
|
# Wrapper modules
|
||||||
'FloatFunctional',
|
"FloatFunctional",
|
||||||
'FXFloatFunctional',
|
"FXFloatFunctional",
|
||||||
'QFunctional',
|
"QFunctional",
|
||||||
]
|
]
|
||||||
|
@ -1 +1 @@
|
|||||||
from .modules import * # noqa: F403
|
from torch.nn.quantized._reference.modules import * # noqa: F403
|
||||||
|
@ -9,23 +9,31 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from torch.ao.nn.quantized.reference.modules.conv import (
|
||||||
|
Conv1d,
|
||||||
|
Conv2d,
|
||||||
|
Conv3d,
|
||||||
|
ConvTranspose1d,
|
||||||
|
ConvTranspose2d,
|
||||||
|
ConvTranspose3d,
|
||||||
|
)
|
||||||
from torch.ao.nn.quantized.reference.modules.linear import Linear
|
from torch.ao.nn.quantized.reference.modules.linear import Linear
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
|
from torch.ao.nn.quantized.reference.modules.rnn import GRUCell, LSTM, LSTMCell, RNNCell
|
||||||
from torch.ao.nn.quantized.reference.modules.rnn import RNNCell, LSTMCell, GRUCell, LSTM
|
|
||||||
from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
|
from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Linear',
|
"Linear",
|
||||||
'Conv1d',
|
"Conv1d",
|
||||||
'Conv2d',
|
"Conv2d",
|
||||||
'Conv3d',
|
"Conv3d",
|
||||||
'ConvTranspose1d',
|
"ConvTranspose1d",
|
||||||
'ConvTranspose2d',
|
"ConvTranspose2d",
|
||||||
'ConvTranspose3d',
|
"ConvTranspose3d",
|
||||||
'RNNCell',
|
"RNNCell",
|
||||||
'LSTMCell',
|
"LSTMCell",
|
||||||
'GRUCell',
|
"GRUCell",
|
||||||
'LSTM',
|
"LSTM",
|
||||||
'Embedding',
|
"Embedding",
|
||||||
'EmbeddingBag',
|
"EmbeddingBag",
|
||||||
]
|
]
|
||||||
|
@ -9,11 +9,13 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import _ConvNd
|
from torch.ao.nn.quantized.reference.modules.conv import (
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import Conv1d
|
_ConvNd,
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import Conv2d
|
_ConvTransposeNd,
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import Conv3d
|
Conv1d,
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import _ConvTransposeNd
|
Conv2d,
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose1d
|
Conv3d,
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose2d
|
ConvTranspose1d,
|
||||||
from torch.ao.nn.quantized.reference.modules.conv import ConvTranspose3d
|
ConvTranspose2d,
|
||||||
|
ConvTranspose3d,
|
||||||
|
)
|
||||||
|
@ -9,9 +9,11 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.ao.nn.quantized.reference.modules.rnn import RNNCellBase
|
from torch.ao.nn.quantized.reference.modules.rnn import (
|
||||||
from torch.ao.nn.quantized.reference.modules.rnn import RNNCell
|
GRUCell,
|
||||||
from torch.ao.nn.quantized.reference.modules.rnn import LSTMCell
|
LSTM,
|
||||||
from torch.ao.nn.quantized.reference.modules.rnn import GRUCell
|
LSTMCell,
|
||||||
from torch.ao.nn.quantized.reference.modules.rnn import RNNBase
|
RNNBase,
|
||||||
from torch.ao.nn.quantized.reference.modules.rnn import LSTM
|
RNNCell,
|
||||||
|
RNNCellBase,
|
||||||
|
)
|
||||||
|
@ -9,5 +9,4 @@ appropriate file under the `torch/ao/nn/quantized/reference`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.ao.nn.quantized.reference.modules.sparse import Embedding
|
from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag
|
||||||
from torch.ao.nn.quantized.reference.modules.sparse import EmbeddingBag
|
|
||||||
|
@ -8,8 +8,11 @@ If you are adding a new entry/functionality, please, add it to the
|
|||||||
appropriate file under the `torch/ao/nn/quantized/reference`,
|
appropriate file under the `torch/ao/nn/quantized/reference`,
|
||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
from torch.ao.nn.quantized.reference.modules.utils import _quantize_weight
|
|
||||||
from torch.ao.nn.quantized.reference.modules.utils import _quantize_and_dequantize_weight
|
from torch.ao.nn.quantized.reference.modules.utils import (
|
||||||
from torch.ao.nn.quantized.reference.modules.utils import _save_weight_qparams
|
_get_weight_qparam_keys,
|
||||||
from torch.ao.nn.quantized.reference.modules.utils import _get_weight_qparam_keys
|
_quantize_and_dequantize_weight,
|
||||||
from torch.ao.nn.quantized.reference.modules.utils import ReferenceQuantizedModule
|
_quantize_weight,
|
||||||
|
_save_weight_qparams,
|
||||||
|
ReferenceQuantizedModule,
|
||||||
|
)
|
||||||
|
@ -8,25 +8,36 @@ appropriate file under the `torch/ao/nn/quantized/dynamic`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.ao.nn.quantized.dynamic.modules import conv
|
from torch.ao.nn.quantized.dynamic.modules import conv, linear, rnn
|
||||||
from torch.ao.nn.quantized.dynamic.modules import linear
|
from torch.ao.nn.quantized.dynamic.modules.conv import (
|
||||||
from torch.ao.nn.quantized.dynamic.modules import rnn
|
Conv1d,
|
||||||
|
Conv2d,
|
||||||
from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
|
Conv3d,
|
||||||
|
ConvTranspose1d,
|
||||||
|
ConvTranspose2d,
|
||||||
|
ConvTranspose3d,
|
||||||
|
)
|
||||||
from torch.ao.nn.quantized.dynamic.modules.linear import Linear
|
from torch.ao.nn.quantized.dynamic.modules.linear import Linear
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM, GRU, LSTMCell, RNNCell, GRUCell
|
from torch.ao.nn.quantized.dynamic.modules.rnn import (
|
||||||
|
GRU,
|
||||||
|
GRUCell,
|
||||||
|
LSTM,
|
||||||
|
LSTMCell,
|
||||||
|
RNNCell,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Linear',
|
"Linear",
|
||||||
'LSTM',
|
"LSTM",
|
||||||
'GRU',
|
"GRU",
|
||||||
'LSTMCell',
|
"LSTMCell",
|
||||||
'RNNCell',
|
"RNNCell",
|
||||||
'GRUCell',
|
"GRUCell",
|
||||||
'Conv1d',
|
"Conv1d",
|
||||||
'Conv2d',
|
"Conv2d",
|
||||||
'Conv3d',
|
"Conv3d",
|
||||||
'ConvTranspose1d',
|
"ConvTranspose1d",
|
||||||
'ConvTranspose2d',
|
"ConvTranspose2d",
|
||||||
'ConvTranspose3d',
|
"ConvTranspose3d",
|
||||||
]
|
]
|
||||||
|
@ -8,11 +8,21 @@ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
|
from torch.ao.nn.quantized.dynamic.modules.conv import (
|
||||||
|
Conv1d,
|
||||||
|
Conv2d,
|
||||||
|
Conv3d,
|
||||||
|
ConvTranspose1d,
|
||||||
|
ConvTranspose2d,
|
||||||
|
ConvTranspose3d,
|
||||||
|
)
|
||||||
|
|
||||||
from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d
|
|
||||||
from torch.ao.nn.quantized.dynamic.modules.conv import Conv2d
|
__all__ = [
|
||||||
from torch.ao.nn.quantized.dynamic.modules.conv import Conv3d
|
"Conv1d",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose1d
|
"Conv2d",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose2d
|
"Conv3d",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose3d
|
"ConvTranspose1d",
|
||||||
|
"ConvTranspose2d",
|
||||||
|
"ConvTranspose3d",
|
||||||
|
]
|
||||||
|
@ -8,15 +8,27 @@ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell',
|
from torch.ao.nn.quantized.dynamic.modules.rnn import (
|
||||||
'GRUCell']
|
GRU,
|
||||||
|
GRUCell,
|
||||||
|
LSTM,
|
||||||
|
LSTMCell,
|
||||||
|
pack_weight_bias,
|
||||||
|
PackedParameter,
|
||||||
|
RNNBase,
|
||||||
|
RNNCell,
|
||||||
|
RNNCellBase,
|
||||||
|
)
|
||||||
|
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import pack_weight_bias
|
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import PackedParameter
|
__all__ = [
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNBase
|
"pack_weight_bias",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM
|
"PackedParameter",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import GRU
|
"RNNBase",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCellBase
|
"LSTM",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCell
|
"GRU",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import LSTMCell
|
"RNNCellBase",
|
||||||
from torch.ao.nn.quantized.dynamic.modules.rnn import GRUCell
|
"RNNCell",
|
||||||
|
"LSTMCell",
|
||||||
|
"GRUCell",
|
||||||
|
]
|
||||||
|
@ -5,66 +5,93 @@ Note::
|
|||||||
Please, use `torch.ao.nn.quantized` instead.
|
Please, use `torch.ao.nn.quantized` instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU
|
|
||||||
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
|
|
||||||
from torch.ao.nn.quantized.modules.conv import Conv1d, Conv2d, Conv3d
|
|
||||||
from torch.ao.nn.quantized.modules.conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
|
|
||||||
from torch.ao.nn.quantized.modules.dropout import Dropout
|
|
||||||
from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
|
|
||||||
from torch.ao.nn.quantized.modules.functional_modules import FloatFunctional, FXFloatFunctional, QFunctional
|
|
||||||
from torch.ao.nn.quantized.modules.linear import Linear
|
|
||||||
from torch.ao.nn.quantized.modules.normalization import LayerNorm, GroupNorm, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
|
|
||||||
from torch.ao.nn.quantized.modules.rnn import LSTM
|
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules import MaxPool2d
|
|
||||||
from torch.ao.nn.quantized.modules import Quantize, DeQuantize
|
|
||||||
|
|
||||||
# The following imports are needed in case the user decides
|
# The following imports are needed in case the user decides
|
||||||
# to import the files directly,
|
# to import the files directly,
|
||||||
# s.a. `from torch.nn.quantized.modules.conv import ...`.
|
# s.a. `from torch.nn.quantized.modules.conv import ...`.
|
||||||
# No need to add them to the `__all__`.
|
# No need to add them to the `__all__`.
|
||||||
from torch.ao.nn.quantized.modules import activation
|
from torch.ao.nn.quantized.modules import (
|
||||||
from torch.ao.nn.quantized.modules import batchnorm
|
activation,
|
||||||
from torch.ao.nn.quantized.modules import conv
|
batchnorm,
|
||||||
from torch.ao.nn.quantized.modules import dropout
|
conv,
|
||||||
from torch.ao.nn.quantized.modules import embedding_ops
|
DeQuantize,
|
||||||
from torch.ao.nn.quantized.modules import functional_modules
|
dropout,
|
||||||
from torch.ao.nn.quantized.modules import linear
|
embedding_ops,
|
||||||
from torch.ao.nn.quantized.modules import normalization
|
functional_modules,
|
||||||
from torch.ao.nn.quantized.modules import rnn
|
linear,
|
||||||
from torch.ao.nn.quantized.modules import utils
|
MaxPool2d,
|
||||||
|
normalization,
|
||||||
|
Quantize,
|
||||||
|
rnn,
|
||||||
|
utils,
|
||||||
|
)
|
||||||
|
from torch.ao.nn.quantized.modules.activation import (
|
||||||
|
ELU,
|
||||||
|
Hardswish,
|
||||||
|
LeakyReLU,
|
||||||
|
MultiheadAttention,
|
||||||
|
PReLU,
|
||||||
|
ReLU6,
|
||||||
|
Sigmoid,
|
||||||
|
Softmax,
|
||||||
|
)
|
||||||
|
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
|
||||||
|
from torch.ao.nn.quantized.modules.conv import (
|
||||||
|
Conv1d,
|
||||||
|
Conv2d,
|
||||||
|
Conv3d,
|
||||||
|
ConvTranspose1d,
|
||||||
|
ConvTranspose2d,
|
||||||
|
ConvTranspose3d,
|
||||||
|
)
|
||||||
|
from torch.ao.nn.quantized.modules.dropout import Dropout
|
||||||
|
from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
|
||||||
|
from torch.ao.nn.quantized.modules.functional_modules import (
|
||||||
|
FloatFunctional,
|
||||||
|
FXFloatFunctional,
|
||||||
|
QFunctional,
|
||||||
|
)
|
||||||
|
from torch.ao.nn.quantized.modules.linear import Linear
|
||||||
|
from torch.ao.nn.quantized.modules.normalization import (
|
||||||
|
GroupNorm,
|
||||||
|
InstanceNorm1d,
|
||||||
|
InstanceNorm2d,
|
||||||
|
InstanceNorm3d,
|
||||||
|
LayerNorm,
|
||||||
|
)
|
||||||
|
from torch.ao.nn.quantized.modules.rnn import LSTM
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BatchNorm2d',
|
"BatchNorm2d",
|
||||||
'BatchNorm3d',
|
"BatchNorm3d",
|
||||||
'Conv1d',
|
"Conv1d",
|
||||||
'Conv2d',
|
"Conv2d",
|
||||||
'Conv3d',
|
"Conv3d",
|
||||||
'ConvTranspose1d',
|
"ConvTranspose1d",
|
||||||
'ConvTranspose2d',
|
"ConvTranspose2d",
|
||||||
'ConvTranspose3d',
|
"ConvTranspose3d",
|
||||||
'DeQuantize',
|
"DeQuantize",
|
||||||
'ELU',
|
"ELU",
|
||||||
'Embedding',
|
"Embedding",
|
||||||
'EmbeddingBag',
|
"EmbeddingBag",
|
||||||
'GroupNorm',
|
"GroupNorm",
|
||||||
'Hardswish',
|
"Hardswish",
|
||||||
'InstanceNorm1d',
|
"InstanceNorm1d",
|
||||||
'InstanceNorm2d',
|
"InstanceNorm2d",
|
||||||
'InstanceNorm3d',
|
"InstanceNorm3d",
|
||||||
'LayerNorm',
|
"LayerNorm",
|
||||||
'LeakyReLU',
|
"LeakyReLU",
|
||||||
'Linear',
|
"Linear",
|
||||||
'LSTM',
|
"LSTM",
|
||||||
'MultiheadAttention',
|
"MultiheadAttention",
|
||||||
'Quantize',
|
"Quantize",
|
||||||
'ReLU6',
|
"ReLU6",
|
||||||
'Sigmoid',
|
"Sigmoid",
|
||||||
'Softmax',
|
"Softmax",
|
||||||
'Dropout',
|
"Dropout",
|
||||||
'PReLU',
|
"PReLU",
|
||||||
# Wrapper modules
|
# Wrapper modules
|
||||||
'FloatFunctional',
|
"FloatFunctional",
|
||||||
'FXFloatFunctional',
|
"FXFloatFunctional",
|
||||||
'QFunctional',
|
"QFunctional",
|
||||||
]
|
]
|
||||||
|
@ -8,11 +8,13 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.activation import ELU
|
from torch.ao.nn.quantized.modules.activation import (
|
||||||
from torch.ao.nn.quantized.modules.activation import Hardswish
|
ELU,
|
||||||
from torch.ao.nn.quantized.modules.activation import LeakyReLU
|
Hardswish,
|
||||||
from torch.ao.nn.quantized.modules.activation import MultiheadAttention
|
LeakyReLU,
|
||||||
from torch.ao.nn.quantized.modules.activation import PReLU
|
MultiheadAttention,
|
||||||
from torch.ao.nn.quantized.modules.activation import ReLU6
|
PReLU,
|
||||||
from torch.ao.nn.quantized.modules.activation import Sigmoid
|
ReLU6,
|
||||||
from torch.ao.nn.quantized.modules.activation import Softmax
|
Sigmoid,
|
||||||
|
Softmax,
|
||||||
|
)
|
||||||
|
@ -8,5 +8,4 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d
|
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
|
||||||
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm3d
|
|
||||||
|
@ -8,14 +8,22 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
|
from torch.ao.nn.quantized.modules.conv import (
|
||||||
|
_reverse_repeat_padding,
|
||||||
|
Conv1d,
|
||||||
|
Conv2d,
|
||||||
|
Conv3d,
|
||||||
|
ConvTranspose1d,
|
||||||
|
ConvTranspose2d,
|
||||||
|
ConvTranspose3d,
|
||||||
|
)
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
|
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.conv import Conv1d
|
__all__ = [
|
||||||
from torch.ao.nn.quantized.modules.conv import Conv2d
|
"Conv1d",
|
||||||
from torch.ao.nn.quantized.modules.conv import Conv3d
|
"Conv2d",
|
||||||
|
"Conv3d",
|
||||||
from torch.ao.nn.quantized.modules.conv import ConvTranspose1d
|
"ConvTranspose1d",
|
||||||
from torch.ao.nn.quantized.modules.conv import ConvTranspose2d
|
"ConvTranspose2d",
|
||||||
from torch.ao.nn.quantized.modules.conv import ConvTranspose3d
|
"ConvTranspose3d",
|
||||||
|
]
|
||||||
|
@ -8,6 +8,7 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['Dropout']
|
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.dropout import Dropout
|
from torch.ao.nn.quantized.modules.dropout import Dropout
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Dropout"]
|
||||||
|
@ -8,8 +8,11 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag']
|
from torch.ao.nn.quantized.modules.embedding_ops import (
|
||||||
|
Embedding,
|
||||||
|
EmbeddingBag,
|
||||||
|
EmbeddingPackedParams,
|
||||||
|
)
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.embedding_ops import Embedding
|
|
||||||
from torch.ao.nn.quantized.modules.embedding_ops import EmbeddingBag
|
__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]
|
||||||
from torch.ao.nn.quantized.modules.embedding_ops import EmbeddingPackedParams
|
|
||||||
|
@ -8,8 +8,11 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional']
|
from torch.ao.nn.quantized.modules.functional_modules import (
|
||||||
|
FloatFunctional,
|
||||||
|
FXFloatFunctional,
|
||||||
|
QFunctional,
|
||||||
|
)
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.functional_modules import FloatFunctional
|
|
||||||
from torch.ao.nn.quantized.modules.functional_modules import FXFloatFunctional
|
__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]
|
||||||
from torch.ao.nn.quantized.modules.functional_modules import QFunctional
|
|
||||||
|
@ -8,7 +8,7 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['LinearPackedParams', 'Linear']
|
from torch.ao.nn.quantized.modules.linear import Linear, LinearPackedParams
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.linear import Linear
|
|
||||||
from torch.ao.nn.quantized.modules.linear import LinearPackedParams
|
__all__ = ["LinearPackedParams", "Linear"]
|
||||||
|
@ -8,10 +8,19 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
|
from torch.ao.nn.quantized.modules.normalization import (
|
||||||
|
GroupNorm,
|
||||||
|
InstanceNorm1d,
|
||||||
|
InstanceNorm2d,
|
||||||
|
InstanceNorm3d,
|
||||||
|
LayerNorm,
|
||||||
|
)
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.normalization import LayerNorm
|
|
||||||
from torch.ao.nn.quantized.modules.normalization import GroupNorm
|
__all__ = [
|
||||||
from torch.ao.nn.quantized.modules.normalization import InstanceNorm1d
|
"LayerNorm",
|
||||||
from torch.ao.nn.quantized.modules.normalization import InstanceNorm2d
|
"GroupNorm",
|
||||||
from torch.ao.nn.quantized.modules.normalization import InstanceNorm3d
|
"InstanceNorm1d",
|
||||||
|
"InstanceNorm2d",
|
||||||
|
"InstanceNorm3d",
|
||||||
|
]
|
||||||
|
@ -8,8 +8,10 @@ appropriate file under the `torch/ao/nn/quantized/modules`,
|
|||||||
while adding an import statement here.
|
while adding an import statement here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.ao.nn.quantized.modules.utils import _ntuple_from_first
|
from torch.ao.nn.quantized.modules.utils import (
|
||||||
from torch.ao.nn.quantized.modules.utils import _pair_from_first
|
_hide_packed_params_repr,
|
||||||
from torch.ao.nn.quantized.modules.utils import _quantize_weight
|
_ntuple_from_first,
|
||||||
from torch.ao.nn.quantized.modules.utils import _hide_packed_params_repr
|
_pair_from_first,
|
||||||
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
|
_quantize_weight,
|
||||||
|
WeightedQuantizedModule,
|
||||||
|
)
|
||||||
|
Reference in New Issue
Block a user