Bring docstring to .pyi file (#114705)

Fixes #37762

Since the original issue hasn't been making progress for more than 3 years, I am attempting to make this PR to at least make some progress forward.

This PR attempts to add docstring to the `.pyi` files. The docstrings are read from [`_torch_docs`](https://github.com/pytorch/pytorch/blob/main/torch/_torch_docs.py) by mocking [`_add_docstr`](9f073ae304/torch/csrc/Module.cpp (L329)), which is the only function used to add docstring.

Luckily, `_torch_docs` has no dependencies for other components of PyTorch, and can be imported without compiling `torch._C` with `_add_docstr` mocked.

The generated `.pyi` file looks something like the following:

[_VariableFunctions.pyi.txt](https://github.com/pytorch/pytorch/files/13494263/_VariableFunctions.pyi.txt)

<img width="787" alt="image" src="https://github.com/pytorch/pytorch/assets/6421097/73c2e884-f06b-4529-8301-0ca0b9de173c">

And the docstring can be picked up by VSCode:

<img width="839" alt="image" src="https://github.com/pytorch/pytorch/assets/6421097/1999dc89-a591-4c7a-80ac-aa3456672af4">

<img width="908" alt="image" src="https://github.com/pytorch/pytorch/assets/6421097/ecf3fa92-9822-4a3d-9263-d224d87ac288">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114705
Approved by: https://github.com/albanD
This commit is contained in:
Shawn Zhong
2024-01-09 18:37:12 +00:00
committed by PyTorch MergeBot
parent cfd0728b24
commit 0dd5deeced
3 changed files with 50 additions and 2 deletions

View File

@ -1,7 +1,11 @@
import argparse
import collections
import importlib
import sys
from pprint import pformat
from typing import Dict, List, Sequence
from unittest.mock import Mock, patch
from torchgen.api.python import (
PythonSignatureGroup,
@ -580,6 +584,43 @@ def gen_nn_functional(fm: FileManager) -> None:
)
"""
We gather the docstrings for torch with the following steps:
1. Mock torch and torch._C, which are the only dependencies of the docs files
2. Mock the _add_docstr function to save the docstrings
3. Import the docs files to trigger mocked _add_docstr and collect docstrings
"""
def gather_docstrs() -> Dict[str, str]:
docstrs = {}
def mock_add_docstr(func: Mock, docstr: str) -> None:
docstrs[func._extract_mock_name()] = docstr.strip()
# sys.modules and sys.path are restored after the context manager exits
with patch.dict(sys.modules), patch.object(sys, "path", sys.path + ["torch"]):
# mock the torch module and torch._C._add_docstr
sys.modules["torch"] = Mock(name="torch")
sys.modules["torch._C"] = Mock(_add_docstr=mock_add_docstr)
# manually import torch._torch_docs and torch._tensor_docs to trigger
# the mocked _add_docstr and collect docstrings
sys.modules["torch._torch_docs"] = importlib.import_module("_torch_docs")
sys.modules["torch._tensor_docs"] = importlib.import_module("_tensor_docs")
return docstrs
def add_docstr_to_hint(docstr: str, hint: str) -> str:
if "..." in hint: # function or method
assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
hint = hint[:-3] # remove "..."
return "\n ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."])
else: # attribute or property
return f'{hint}\nr"""{docstr}"""\n'
def gen_pyi(
native_yaml_path: str,
tags_yaml_path: str,
@ -972,11 +1013,15 @@ def gen_pyi(
)
return hint
docstrs = gather_docstrs()
function_hints = []
for name, hints in sorted(unsorted_function_hints.items()):
hints = [replace_special_case(h) for h in hints]
if len(hints) > 1:
hints = ["@overload\n" + h for h in hints]
docstr = docstrs.get(f"torch.{name}")
if docstr is not None:
hints = [add_docstr_to_hint(docstr, h) for h in hints]
function_hints += hints
# Generate type signatures for Tensor methods
@ -1199,6 +1244,9 @@ def gen_pyi(
for name, hints in sorted(unsorted_tensor_method_hints.items()):
if len(hints) > 1:
hints = ["@overload\n" + h for h in hints]
docstr = docstrs.get(f"torch._C.TensorBase.{name}")
if docstr is not None:
hints = [add_docstr_to_hint(docstr, h) for h in hints]
tensor_method_hints += hints
# TODO: Missing type hints for nn

View File

@ -2,7 +2,7 @@
import torch._C
from torch._C import _add_docstr as add_docstr
from ._torch_docs import parse_kwargs, reproducibility_notes
from torch._torch_docs import parse_kwargs, reproducibility_notes
def add_docstr_all(method, docstr):

View File

@ -11104,7 +11104,7 @@ always be real-valued, even if :attr:`input` is complex.
.. warning:: If the distance between any two singular values is close to zero, the gradients with respect to
`U` and `V` will be numerically unstable, as they depends on
:math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. The same happens when the matrix
has small singular values, as these gradients also depend on `S⁻¹`.
has small singular values, as these gradients also depend on `S^{-1}`.
.. warning:: For complex-valued :attr:`input` the singular value decomposition is not unique,
as `U` and `V` may be multiplied by an arbitrary phase factor :math:`e^{i \phi}` on every column.