import inspect import os import re from pathlib import Path import torch import torch._dynamo as torchdynamo from torch._export.db.case import ExportCase from torch._export.db.examples import all_examples from torch.export import export PWD = Path(__file__).absolute().parent ROOT = Path(__file__).absolute().parent.parent.parent.parent SOURCE = ROOT / Path("source") EXPORTDB_SOURCE = SOURCE / Path("generated") / Path("exportdb") def generate_example_rst(example_case: ExportCase): """ Generates the .rst files for all the examples in db/examples/ """ model = example_case.model tags = ", ".join(f":doc:`{tag} <{tag}>`" for tag in example_case.tags) source_file = ( inspect.getfile(model.__class__) if isinstance(model, torch.nn.Module) else inspect.getfile(model) ) with open(source_file) as file: source_code = file.read() source_code = source_code.replace("\n", "\n ") splitted_source_code = re.split(r"@export_rewrite_case.*\n", source_code) assert len(splitted_source_code) in { 1, 2, }, f"more than one @export_rewrite_case decorator in {source_code}" more_arguments = "" if example_case.example_kwargs: more_arguments += ", example_kwargs" if example_case.dynamic_shapes: more_arguments += ", dynamic_shapes=dynamic_shapes" # Generate contents of the .rst file title = f"{example_case.name}" doc_contents = f"""{title} {'^' * (len(title))} .. note:: Tags: {tags} Support Level: {example_case.support_level.name} Original source code: .. code-block:: python {splitted_source_code[0]} torch.export.export(model, example_args{more_arguments}) Result: .. code-block:: """ # Get resulting graph from dynamo trace try: exported_program = export( model, example_case.example_args, example_case.example_kwargs, dynamic_shapes=example_case.dynamic_shapes, strict=True, ) graph_output = str(exported_program) graph_output = re.sub(r" # File(.|\n)*?\n", "", graph_output) graph_output = graph_output.replace("\n", "\n ") output = f" {graph_output}" except torchdynamo.exc.Unsupported as e: output = " Unsupported: " + str(e).split("\n")[0] except AssertionError as e: output = " AssertionError: " + str(e).split("\n")[0] except RuntimeError as e: output = " RuntimeError: " + str(e).split("\n")[0] doc_contents += output + "\n" if len(splitted_source_code) == 2: doc_contents += f"""\n You can rewrite the example above to something like the following: .. code-block:: python {splitted_source_code[1]} """ return doc_contents def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules): """ Generates the index.rst file """ support_contents = "" for k, v in support_level_to_modules.items(): support_level = k.name.lower().replace("_", " ").title() module_contents = "\n\n".join(v) support_contents += f""" {support_level} {'-' * (len(support_level))} {module_contents} """ tag_names = "\n ".join(t for t in tag_to_modules.keys()) with open(os.path.join(PWD, "blurb.txt")) as file: blurb = file.read() # Generate contents of the .rst file doc_contents = f""".. _torch.export_db: ExportDB ======== {blurb} .. toctree:: :maxdepth: 1 :caption: Tags {tag_names} {support_contents} """ with open(os.path.join(EXPORTDB_SOURCE, "index.rst"), "w") as f: f.write(doc_contents) def generate_tag_rst(tag_to_modules): """ For each tag that shows up in each ExportCase.tag, generate an .rst file containing all the examples that have that tag. """ for tag, modules_rst in tag_to_modules.items(): doc_contents = f"{tag}\n{'=' * (len(tag) + 4)}\n" full_modules_rst = "\n\n".join(modules_rst) full_modules_rst = re.sub( r"={3,}", lambda match: "-" * len(match.group()), full_modules_rst ) doc_contents += full_modules_rst with open(os.path.join(EXPORTDB_SOURCE, f"{tag}.rst"), "w") as f: f.write(doc_contents) def generate_rst(): if not os.path.exists(EXPORTDB_SOURCE): os.makedirs(EXPORTDB_SOURCE) example_cases = all_examples() tag_to_modules = {} support_level_to_modules = {} for example_case in example_cases.values(): doc_contents = generate_example_rst(example_case) for tag in example_case.tags: tag_to_modules.setdefault(tag, []).append(doc_contents) support_level_to_modules.setdefault(example_case.support_level, []).append( doc_contents ) generate_tag_rst(tag_to_modules) generate_index_rst(example_cases, tag_to_modules, support_level_to_modules) if __name__ == "__main__": generate_rst()