Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[BE][Easy] improve submodule discovery for torch.ao type annotations (#144680)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144680
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: c40d917182
Commit: bee84e88f8
@@ -1,17 +1,29 @@
-# mypy: allow-untyped-defs
 # torch.ao is a package with a lot of interdependencies.
 # We will use lazy import to avoid cyclic dependencies here.
 
+from typing import TYPE_CHECKING as _TYPE_CHECKING
+
+
+if _TYPE_CHECKING:
+    from types import ModuleType
+
+    from torch.ao import (  # noqa: TC004
+        nn as nn,
+        ns as ns,
+        pruning as pruning,
+        quantization as quantization,
+    )
+
+
 __all__ = [
     "nn",
     "ns",
-    "quantization",
     "pruning",
+    "quantization",
 ]
 
 
-def __getattr__(name):
+def __getattr__(name: str) -> "ModuleType":
     if name in __all__:
         import importlib
 
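The hunk above pairs a PEP 562 module-level __getattr__ (submodules are imported lazily, on first attribute access, to avoid import cycles) with an "if TYPE_CHECKING:" block that re-imports the same submodules purely for static analysis, so type checkers and IDEs can still discover them. A minimal sketch of the same pattern, assuming a hypothetical package "mypkg" with a single submodule "utils" (illustrative only, not code from this commit):

# __init__.py of a hypothetical package "mypkg" with one submodule "utils"
from typing import TYPE_CHECKING as _TYPE_CHECKING

if _TYPE_CHECKING:
    from types import ModuleType

    from mypkg import utils as utils  # seen by type checkers/IDEs only

__all__ = ["utils"]


def __getattr__(name: str) -> "ModuleType":
    # PEP 562: called only when normal module attribute lookup fails,
    # so each listed submodule is imported on first access.
    if name in __all__:
        import importlib

        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

At runtime the TYPE_CHECKING block is skipped, so the eager re-imports never execute and cannot reintroduce the cycles the lazy loader is there to avoid.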
@@ -1,10 +1,21 @@
-# mypy: allow-untyped-defs
 # We are exposing all subpackages to the end-user.
 # Because of possible inter-dependency, we want to avoid
 # the cyclic imports, thus implementing lazy version
 # as per https://peps.python.org/pep-0562/
 
-import importlib
+from typing import TYPE_CHECKING as _TYPE_CHECKING
+
+
+if _TYPE_CHECKING:
+    from types import ModuleType
+
+    from torch.ao.nn import (  # noqa: TC004
+        intrinsic as intrinsic,
+        qat as qat,
+        quantizable as quantizable,
+        quantized as quantized,
+        sparse as sparse,
+    )
 
 
 __all__ = [
@@ -16,7 +27,9 @@ __all__ = [
 ]
 
 
-def __getattr__(name):
+def __getattr__(name: str) -> "ModuleType":
     if name in __all__:
+        import importlib
+
         return importlib.import_module("." + name, __name__)
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
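For reference, the __getattr__ fallback kept in the hunk above determines the runtime behavior of attribute access on the package. A quick illustration, assuming a local PyTorch install (the unknown attribute name is made up):

import torch.ao.nn

# listed in __all__: imported on demand (or returned directly if torch
# already imported it eagerly elsewhere)
quantized = torch.ao.nn.quantized
print(type(quantized))  # <class 'module'>

try:
    torch.ao.nn.no_such_submodule  # not in __all__: falls through to the raise
except AttributeError as err:
    print(err)  # module 'torch.ao.nn' has no attribute 'no_such_submodule'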
@@ -49,7 +49,7 @@ class LSTM(torch.ao.nn.quantizable.LSTM):
 
     @classmethod
     def from_observed(cls, other):
-        assert isinstance(other, cls._FLOAT_MODULE)
+        assert isinstance(other, cls._FLOAT_MODULE)  # type: ignore[has-type]
         converted = torch.ao.quantization.convert(
             other, inplace=False, remove_qconfig=True
         )