[BE][Easy] improve submodule discovery for torch.ao type annotations (#144680)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144680
Approved by: https://github.com/Skylion007
This commit is contained in:
Xuehai Pan
2025-01-13 21:39:17 +08:00
committed by PyTorch MergeBot
parent c40d917182
commit bee84e88f8
3 changed files with 32 additions and 7 deletions

View File

@ -1,17 +1,29 @@
# mypy: allow-untyped-defs
# torch.ao is a package with a lot of interdependencies.
# We will use lazy import to avoid cyclic dependencies here.
from typing import TYPE_CHECKING as _TYPE_CHECKING
if _TYPE_CHECKING:
from types import ModuleType
from torch.ao import ( # noqa: TC004
nn as nn,
ns as ns,
pruning as pruning,
quantization as quantization,
)
__all__ = [
"nn",
"ns",
"quantization",
"pruning",
"quantization",
]
def __getattr__(name):
def __getattr__(name: str) -> "ModuleType":
if name in __all__:
import importlib

View File

@ -1,10 +1,21 @@
# mypy: allow-untyped-defs
# We are exposing all subpackages to the end-user.
# Because of possible inter-dependency, we want to avoid
# the cyclic imports, thus implementing lazy version
# as per https://peps.python.org/pep-0562/
import importlib
from typing import TYPE_CHECKING as _TYPE_CHECKING
if _TYPE_CHECKING:
from types import ModuleType
from torch.ao.nn import ( # noqa: TC004
intrinsic as intrinsic,
qat as qat,
quantizable as quantizable,
quantized as quantized,
sparse as sparse,
)
__all__ = [
@ -16,7 +27,9 @@ __all__ = [
]
def __getattr__(name):
def __getattr__(name: str) -> "ModuleType":
if name in __all__:
import importlib
return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -49,7 +49,7 @@ class LSTM(torch.ao.nn.quantizable.LSTM):
@classmethod
def from_observed(cls, other):
assert isinstance(other, cls._FLOAT_MODULE)
assert isinstance(other, cls._FLOAT_MODULE) # type: ignore[has-type]
converted = torch.ao.quantization.convert(
other, inplace=False, remove_qconfig=True
)