Add torch.dtype instances to the public API (#119307)

Fixes #91908

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119307
Approved by: https://github.com/albanD
This commit is contained in:
Tamir Cohen
2024-02-07 02:57:45 +00:00
committed by PyTorch MergeBot
parent 8c2fde1fcf
commit 45a79323fe
2 changed files with 11 additions and 1 deletions

View File

@ -460,7 +460,8 @@ class TestPublicBindings(TestCase):
# verifies that each public API has the correct module name and naming semantics
def check_one_element(elem, modname, mod, *, is_public, is_all):
obj = getattr(mod, elem)
if not (isinstance(obj, Callable) or inspect.isclass(obj)):
# torch.dtype is not a class nor callable, so we need to check for it separately
if not (isinstance(obj, (Callable, torch.dtype)) or inspect.isclass(obj)):
return
elem_module = getattr(obj, '__module__', None)
# Only used for nice error message below

View File

@ -1513,6 +1513,15 @@ for name in dir(_C._VariableFunctions):
__all__.append(name)
################################################################################
# Add torch.dtype instances to the public API
################################################################################
import torch
for attribute in dir(torch):
if isinstance(getattr(torch, attribute), torch.dtype):
__all__.append(attribute)
################################################################################
# Import TorchDynamo's lazy APIs to avoid circular dependenices