mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18579 ghimport-source-id: 65124c95e49423de4ad1008c65e75057fea09b94 Differential Revision: D14778507 Pulled By: ezyang fbshipit-source-id: 1e4bb79f4800a116ce8fb7af2fefbd34da8d102c
91 lines
3.5 KiB
Python
91 lines
3.5 KiB
Python
import torch
|
|
import unittest
|
|
import os
|
|
import re
|
|
import textwrap
|
|
|
|
|
|
# Directory containing this test file; all doc paths are resolved relative to it.
path = os.path.dirname(os.path.realpath(__file__))
# Location of the Sphinx .rst documentation sources.
rstpath = os.path.join(path, '../docs/source/')
# Location of the torch docstring definitions.
pypath = os.path.join(path, '../torch/_torch_docs.py')

# Regexes extracting the documented identifier from Sphinx autodoc directives.
# Matches ``.. autofunction:: name`` (used in torch.rst):
r1 = re.compile(r'\.\. autofunction:: (\w*)')
# Matches ``.. automethod:: name`` and ``.. autoattribute:: name`` (used in tensors.rst):
r2 = re.compile(r'\.\. auto(?:method|attribute):: (\w*)')
|
|
|
|
|
|
class TestDocCoverage(unittest.TestCase):
    """Checks that the symbols carrying docstrings in the torch Python API and
    the symbols listed in the Sphinx .rst sources stay in sync, so nothing is
    documented in one place but missing from the other."""

    @staticmethod
    def parse_rst(filename, regex):
        """Return the set of identifiers captured by ``regex`` in an .rst file.

        Arguments:
            filename (str): file name, resolved relative to ``rstpath``.
            regex: compiled pattern whose first group captures an identifier
                (one of the module-level ``r1``/``r2`` patterns).
        """
        filename = os.path.join(rstpath, filename)
        ret = set()
        with open(filename, 'r') as f:
            lines = f.readlines()
        for l in lines:
            l = l.strip()
            name = regex.findall(l)
            if name:
                ret.add(name[0])
        return ret

    def test_torch(self):
        # TODO: The algorithm here is kind of unsound; we don't assume
        # every identifier in torch.rst lives in torch by virtue of
        # where it lives; instead, it lives in torch because at the
        # beginning of the file we specified automodule.  This means
        # that this script can get confused if you have, e.g., multiple
        # automodule directives in the torch file.  "Don't do that."
        # (Or fix this to properly handle that case.)

        # get symbols documented in torch.rst
        in_rst = self.parse_rst('torch.rst', r1)
        # symbols that have docstrings but are deliberately not documented
        # in torch.rst
        whitelist = {
            # below are some jit functions
            'wait', 'fork', 'parse_type_comment', 'import_ir_module',
            'import_ir_module_from_buffer', 'merge_type_from_type_comment',
            'parse_ir',

            # below are symbols mistakenly bound to torch.*, but should
            # go to torch.nn.functional.* instead
            'avg_pool1d', 'conv_transpose2d', 'conv_transpose1d', 'conv3d',
            'relu_', 'pixel_shuffle', 'conv2d', 'selu_', 'celu_', 'threshold_',
            'cosine_similarity', 'rrelu_', 'conv_transpose3d', 'conv1d', 'pdist',
            'adaptive_avg_pool1d', 'conv_tbc'
        }
        # every public function-like attribute of torch that has a docstring
        has_docstring = {
            a for a in dir(torch)
            if getattr(torch, a).__doc__ and not a.startswith('_') and
            'function' in type(getattr(torch, a)).__name__
        }
        # sanity-check the whitelist itself: everything in it must still
        # exist in torch.* with a docstring
        self.assertEqual(
            has_docstring & whitelist, whitelist,
            textwrap.dedent('''
            The whitelist in test_docs_coverage.py contains something
            that doesn't have a docstring or is not in torch.*. If you just
            removed something from torch.*, please remove it from the whitelist
            in test_docs_coverage.py'''))
        has_docstring -= whitelist
        # assert they are equal
        self.assertEqual(
            has_docstring, in_rst,
            textwrap.dedent('''
            List of functions documented in torch.rst and in python are different.
            Did you forget to add a new thing to torch.rst, or to whitelist things
            you don't want to document?''')
        )

    def test_tensor(self):
        # get symbols documented in tensors.rst
        in_rst = self.parse_rst('tensors.rst', r2)
        classes = [torch.FloatTensor, torch.LongTensor, torch.ByteTensor]
        # every public attribute with a docstring on any of the tensor classes
        has_docstring = {
            x for c in classes for x in dir(c)
            if not x.startswith('_') and getattr(c, x).__doc__
        }
        self.assertEqual(
            has_docstring, in_rst,
            textwrap.dedent('''
            List of tensor methods documented in tensor.rst and in python are
            different. Did you forget to add a new thing to tensor.rst, or to
            whitelist things you don't want to document?''')
        )
|
|
|
|
|
|
# Run the doc-coverage tests when this file is invoked directly.
if __name__ == '__main__':
    unittest.main()
|