diff --git a/torch/_classes.py b/torch/_classes.py index 069f13dcb679..a811c7c30be6 100644 --- a/torch/_classes.py +++ b/torch/_classes.py @@ -1,15 +1,15 @@ -# mypy: allow-untyped-defs import types +from typing import Any import torch._C class _ClassNamespace(types.ModuleType): - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__("torch.classes" + name) self.name = name - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: proxy = torch._C._get_custom_class_python_wrapper(self.name, attr) if proxy is None: raise RuntimeError(f"Class {self.name}.{attr} not registered!") @@ -22,16 +22,16 @@ class _Classes(types.ModuleType): def __init__(self) -> None: super().__init__("torch.classes") - def __getattr__(self, name): + def __getattr__(self, name: str) -> _ClassNamespace: namespace = _ClassNamespace(name) setattr(self, name, namespace) return namespace @property - def loaded_libraries(self): + def loaded_libraries(self) -> Any: return torch.ops.loaded_libraries - def load_library(self, path): + def load_library(self, path: str) -> None: """ Loads a shared library from the given path into the current process.