Files
pytorch/docs/source/scripts/exportdb/generate_example_rst.py
Xuehai Pan a3abfa5cb5 [BE][Easy][1/19] enforce style for empty lines in import segments (#129752)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129752
Approved by: https://github.com/ezyang, https://github.com/malfet
2024-07-16 00:42:56 +00:00

191 lines
4.8 KiB
Python

import inspect
import os
import re
from pathlib import Path
import torch
import torch._dynamo as torchdynamo
from torch._export.db.case import ExportCase
from torch._export.db.examples import all_examples
from torch.export import export
PWD = Path(__file__).absolute().parent
ROOT = Path(__file__).absolute().parent.parent.parent.parent
SOURCE = ROOT / Path("source")
EXPORTDB_SOURCE = SOURCE / Path("generated") / Path("exportdb")
def generate_example_rst(example_case: ExportCase):
"""
Generates the .rst files for all the examples in db/examples/
"""
model = example_case.model
tags = ", ".join(f":doc:`{tag} <{tag}>`" for tag in example_case.tags)
source_file = (
inspect.getfile(model.__class__)
if isinstance(model, torch.nn.Module)
else inspect.getfile(model)
)
with open(source_file) as file:
source_code = file.read()
source_code = source_code.replace("\n", "\n ")
splitted_source_code = re.split(r"@export_rewrite_case.*\n", source_code)
assert len(splitted_source_code) in {
1,
2,
}, f"more than one @export_rewrite_case decorator in {source_code}"
more_arguments = ""
if example_case.example_kwargs:
more_arguments += ", example_kwargs"
if example_case.dynamic_shapes:
more_arguments += ", dynamic_shapes=dynamic_shapes"
# Generate contents of the .rst file
title = f"{example_case.name}"
doc_contents = f"""{title}
{'^' * (len(title))}
.. note::
Tags: {tags}
Support Level: {example_case.support_level.name}
Original source code:
.. code-block:: python
{splitted_source_code[0]}
torch.export.export(model, example_args{more_arguments})
Result:
.. code-block::
"""
# Get resulting graph from dynamo trace
try:
exported_program = export(
model,
example_case.example_args,
example_case.example_kwargs,
dynamic_shapes=example_case.dynamic_shapes,
)
graph_output = str(exported_program)
graph_output = re.sub(r" # File(.|\n)*?\n", "", graph_output)
graph_output = graph_output.replace("\n", "\n ")
output = f" {graph_output}"
except torchdynamo.exc.Unsupported as e:
output = " Unsupported: " + str(e).split("\n")[0]
except AssertionError as e:
output = " AssertionError: " + str(e).split("\n")[0]
except RuntimeError as e:
output = " RuntimeError: " + str(e).split("\n")[0]
doc_contents += output + "\n"
if len(splitted_source_code) == 2:
doc_contents += f"""\n
You can rewrite the example above to something like the following:
.. code-block:: python
{splitted_source_code[1]}
"""
return doc_contents
def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules):
"""
Generates the index.rst file
"""
support_contents = ""
for k, v in support_level_to_modules.items():
support_level = k.name.lower().replace("_", " ").title()
module_contents = "\n\n".join(v)
support_contents += f"""
{support_level}
{'-' * (len(support_level))}
{module_contents}
"""
tag_names = "\n ".join(t for t in tag_to_modules.keys())
with open(os.path.join(PWD, "blurb.txt")) as file:
blurb = file.read()
# Generate contents of the .rst file
doc_contents = f""".. _torch.export_db:
ExportDB
========
{blurb}
.. toctree::
:maxdepth: 1
:caption: Tags
{tag_names}
{support_contents}
"""
with open(os.path.join(EXPORTDB_SOURCE, "index.rst"), "w") as f:
f.write(doc_contents)
def generate_tag_rst(tag_to_modules):
"""
For each tag that shows up in each ExportCase.tag, generate an .rst file
containing all the examples that have that tag.
"""
for tag, modules_rst in tag_to_modules.items():
doc_contents = f"{tag}\n{'=' * (len(tag) + 4)}\n"
full_modules_rst = "\n\n".join(modules_rst)
full_modules_rst = re.sub(
r"={3,}", lambda match: "-" * len(match.group()), full_modules_rst
)
doc_contents += full_modules_rst
with open(os.path.join(EXPORTDB_SOURCE, f"{tag}.rst"), "w") as f:
f.write(doc_contents)
def generate_rst():
if not os.path.exists(EXPORTDB_SOURCE):
os.makedirs(EXPORTDB_SOURCE)
example_cases = all_examples()
tag_to_modules = {}
support_level_to_modules = {}
for example_case in example_cases.values():
doc_contents = generate_example_rst(example_case)
for tag in example_case.tags:
tag_to_modules.setdefault(tag, []).append(doc_contents)
support_level_to_modules.setdefault(example_case.support_level, []).append(
doc_contents
)
generate_tag_rst(tag_to_modules)
generate_index_rst(example_cases, tag_to_modules, support_level_to_modules)
if __name__ == "__main__":
generate_rst()