mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Bring docstring to .pyi file (#114705)
Fixes #37762
Since the original issue hasn't made progress for more than 3 years, I am attempting to make this PR to at least move things forward.
This PR attempts to add docstring to the `.pyi` files. The docstrings are read from [`_torch_docs`](https://github.com/pytorch/pytorch/blob/main/torch/_torch_docs.py) by mocking [`_add_docstr`](9f073ae304/torch/csrc/Module.cpp (L329)
), which is the only function used to add docstring.
Luckily, `_torch_docs` has no dependencies on other components of PyTorch, and can be imported without compiling `torch._C`, with `_add_docstr` mocked.
The generated `.pyi` file looks something like the following:
[_VariableFunctions.pyi.txt](https://github.com/pytorch/pytorch/files/13494263/_VariableFunctions.pyi.txt)
<img width="787" alt="image" src="https://github.com/pytorch/pytorch/assets/6421097/73c2e884-f06b-4529-8301-0ca0b9de173c">
And the docstring can be picked up by VSCode:
<img width="839" alt="image" src="https://github.com/pytorch/pytorch/assets/6421097/1999dc89-a591-4c7a-80ac-aa3456672af4">
<img width="908" alt="image" src="https://github.com/pytorch/pytorch/assets/6421097/ecf3fa92-9822-4a3d-9263-d224d87ac288">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114705
Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
cfd0728b24
commit
0dd5deeced
@ -1,7 +1,11 @@
|
||||
import argparse
|
||||
import collections
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
from pprint import pformat
|
||||
from typing import Dict, List, Sequence
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from torchgen.api.python import (
|
||||
PythonSignatureGroup,
|
||||
@ -580,6 +584,43 @@ def gen_nn_functional(fm: FileManager) -> None:
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
We gather the docstrings for torch with the following steps:
|
||||
1. Mock torch and torch._C, which are the only dependencies of the docs files
|
||||
2. Mock the _add_docstr function to save the docstrings
|
||||
3. Import the docs files to trigger mocked _add_docstr and collect docstrings
|
||||
"""
|
||||
|
||||
|
||||
def gather_docstrs() -> Dict[str, str]:
    """Collect the docstrings registered through ``torch._C._add_docstr``.

    Strategy:
      1. Mock ``torch`` and ``torch._C`` (the only imports of the docs files).
      2. Replace ``_add_docstr`` with a recorder that saves each docstring
         keyed by the mock's dotted name (e.g. ``torch.add``).
      3. Import the docs files, which triggers the recorder for every entry.

    Returns:
        Mapping from fully-qualified name to its stripped docstring.
    """
    collected: Dict[str, str] = {}

    def record_docstr(func: Mock, docstr: str) -> None:
        # The mock's generated name (e.g. "torch.add") is the lookup key.
        collected[func._extract_mock_name()] = docstr.strip()

    # Both sys.modules and sys.path are restored when the context managers exit.
    with patch.dict(sys.modules), patch.object(sys, "path", sys.path + ["torch"]):
        # Stand-ins for the real torch module and torch._C._add_docstr.
        sys.modules["torch"] = Mock(name="torch")
        sys.modules["torch._C"] = Mock(_add_docstr=record_docstr)

        # Importing the docs modules invokes the mocked _add_docstr,
        # which populates `collected` as a side effect.
        sys.modules["torch._torch_docs"] = importlib.import_module("_torch_docs")
        sys.modules["torch._tensor_docs"] = importlib.import_module("_tensor_docs")

    return collected
|
||||
|
||||
|
||||
def add_docstr_to_hint(docstr: str, hint: str) -> str:
    """Attach *docstr* to a ``.pyi`` stub *hint*.

    A function/method hint (one ending in ``...``) gets the docstring
    spliced in as an indented raw triple-quoted string before the ``...``
    body; an attribute/property hint gets it appended on following lines.

    Raises:
        AssertionError: if a hint containing ``...`` does not end with it.
    """
    if "..." not in hint:
        # Attribute or property: docstring goes right after the hint.
        return f'{hint}\nr"""{docstr}"""\n'
    # Function or method: drop the trailing "..." body and rebuild it
    # with the docstring inserted ahead of it, indented one stub level.
    assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
    pieces = [hint[:-3], 'r"""', *docstr.split("\n"), '"""', "..."]
    return "\n    ".join(pieces)
|
||||
|
||||
|
||||
def gen_pyi(
|
||||
native_yaml_path: str,
|
||||
tags_yaml_path: str,
|
||||
@ -972,11 +1013,15 @@ def gen_pyi(
|
||||
)
|
||||
return hint
|
||||
|
||||
docstrs = gather_docstrs()
|
||||
function_hints = []
|
||||
for name, hints in sorted(unsorted_function_hints.items()):
|
||||
hints = [replace_special_case(h) for h in hints]
|
||||
if len(hints) > 1:
|
||||
hints = ["@overload\n" + h for h in hints]
|
||||
docstr = docstrs.get(f"torch.{name}")
|
||||
if docstr is not None:
|
||||
hints = [add_docstr_to_hint(docstr, h) for h in hints]
|
||||
function_hints += hints
|
||||
|
||||
# Generate type signatures for Tensor methods
|
||||
@ -1199,6 +1244,9 @@ def gen_pyi(
|
||||
for name, hints in sorted(unsorted_tensor_method_hints.items()):
|
||||
if len(hints) > 1:
|
||||
hints = ["@overload\n" + h for h in hints]
|
||||
docstr = docstrs.get(f"torch._C.TensorBase.{name}")
|
||||
if docstr is not None:
|
||||
hints = [add_docstr_to_hint(docstr, h) for h in hints]
|
||||
tensor_method_hints += hints
|
||||
|
||||
# TODO: Missing type hints for nn
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import torch._C
|
||||
from torch._C import _add_docstr as add_docstr
|
||||
from ._torch_docs import parse_kwargs, reproducibility_notes
|
||||
from torch._torch_docs import parse_kwargs, reproducibility_notes
|
||||
|
||||
|
||||
def add_docstr_all(method, docstr):
|
||||
|
@ -11104,7 +11104,7 @@ always be real-valued, even if :attr:`input` is complex.
|
||||
.. warning:: If the distance between any two singular values is close to zero, the gradients with respect to
|
||||
`U` and `V` will be numerically unstable, as they depend on
|
||||
:math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. The same happens when the matrix
|
||||
has small singular values, as these gradients also depend on `S⁻¹`.
|
||||
has small singular values, as these gradients also depend on `S^{-1}`.
|
||||
|
||||
.. warning:: For complex-valued :attr:`input` the singular value decomposition is not unique,
|
||||
as `U` and `V` may be multiplied by an arbitrary phase factor :math:`e^{i \phi}` on every column.
|
||||
|
Reference in New Issue
Block a user