pytorch/docs/source/scripts/exportdb/generate_example_rst.py
ydwu4 6abb8c382c [export] add kwargs support for export. (#105337)
Solving #105242.

During export, the exported function's signature changes multiple times. Suppose we'd like to export f as shown in the following example:
```python
def f(arg1, arg2, kw1, kw2):
  pass

args = (arg1, arg2)
kwargs = {"kw2": arg3, "kw1": arg4}

torch.export(f, args, kwargs)
```
The signature changes multiple times during the export process, in the following order:
1. **gm_torch_level = dynamo.export(f, \*args, \*\*kwargs)**. In this step, we turn all kinds of parameters, such as **positional_only**, **var_positional**, **kw_only**, and **var_kwargs**, into **positional_or_kw**. It also preserves the positional and keyword argument names of the original function (i.e. f in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of kwargs is the **key order** of the kwargs dict (since Python 3.7, dicts preserve key insertion order), not the order in the original function signature. This order is baked into the `_orig_args` variable of gm_torch_level's pytree info. So we'll have:
```python
def gm_torch_level(arg1, arg2, kw2, kw1): ...
```
This difference is acceptable because it is transparent to users of export.
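Here is a minimal sketch of the signature gm_torch_level ends up with for the example above. It uses only the standard `inspect` module. `torch_level_signature` is a hypothetical helper and is not part of Dynamo. It only reproduces the ordering rule described in this step: positional names come from f's signature, keyword names follow the kwargs dict's insertion order, and every parameter becomes positional-or-keyword.

```python
import inspect


def torch_level_signature(f, args, kwargs):
    """Hypothetical helper mimicking gm_torch_level's signature (not Dynamo code).

    Positional names come from f's signature; keyword names follow the insertion
    order of the kwargs dict. Var-positional/var-keyword parameters are ignored
    here for brevity.
    """
    positional_names = list(inspect.signature(f).parameters)[: len(args)]
    names = positional_names + list(kwargs)
    # Every parameter becomes POSITIONAL_OR_KEYWORD, whatever its original kind.
    return inspect.Signature(
        [inspect.Parameter(n, inspect.Parameter.POSITIONAL_OR_KEYWORD) for n in names]
    )


def f(arg1, arg2, kw1, kw2):
    pass


def g(arg1, /, arg2, *, kw1, kw2):
    # Same names as f, but with positional-only and keyword-only parameters.
    pass


args = (1, 2)
kwargs = {"kw2": 3, "kw1": 4}
print(torch_level_signature(f, args, kwargs))  # (arg1, arg2, kw2, kw1)
print(torch_level_signature(g, args, kwargs))  # (arg1, arg2, kw2, kw1)
```

Both f and g yield the same flattened signature. This matches the claim that the parameter kinds of the original function are erased at this stage.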

2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we turn kwargs into positional args in the order gm_torch_level expects, which is recorded in `_orig_args`. The graph signature of the returned gm_aot_export is the flat_args obtained from `pytree.tree_flatten(pos_or_kw_args)`:
```python
import torch.utils._pytree as pytree

flat_args, _ = pytree.tree_flatten(pos_or_kw_args)
def gm_aot_export(*flat_args): ...
```
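Below is a toy sketch of this step, written without torch. Kwargs are placed into positional slots using the order recorded in `_orig_args`, and the result is flattened into a list of leaves. It assumes `_orig_args` can be treated as the list of parameter names in gm_torch_level's order. `to_pos_or_kw_args` and `flatten` are hypothetical stand-ins. `flatten` handles only tuples, lists, and dicts and is not `torch.utils._pytree.tree_flatten`.

```python
def to_pos_or_kw_args(args, kwargs, orig_args):
    """Hypothetical: append kwargs in the order recorded in orig_args.

    Assumes orig_args is a list of parameter names, positional ones first.
    """
    keyword_names = orig_args[len(args):]
    return tuple(args) + tuple(kwargs[name] for name in keyword_names)


def flatten(tree):
    """Toy flatten: collect the leaves of nested tuples/lists/dicts in order."""
    if isinstance(tree, (tuple, list)):
        return [leaf for child in tree for leaf in flatten(child)]
    if isinstance(tree, dict):
        return [leaf for child in tree.values() for leaf in flatten(child)]
    return [tree]


orig_args = ["arg1", "arg2", "kw2", "kw1"]
args = (1, (2, 3))
kwargs = {"kw1": 4, "kw2": [5, 6]}  # caller order differs from orig_args

pos_or_kw_args = to_pos_or_kw_args(args, kwargs, orig_args)
flat_args = flatten(pos_or_kw_args)
print(pos_or_kw_args)  # (1, (2, 3), [5, 6], 4)
print(flat_args)  # [1, 2, 3, 5, 6, 4]
```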

3. **exported_program(\*args, \*\*kwargs)**. The exported artifact is exported_program, a wrapper over gm_aot_export that has the same calling convention as the original function f. To achieve this, we need to (1) specialize the order of kwargs into pos_or_kw_args and (2) flatten pos_or_kw_args into the form gm_aot_export expects. The two steps can be combined into one:
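```python
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree

# At export time: record the structure of (args, kwargs), including kwargs key order.
_, in_spec = pytree.tree_flatten((args, kwargs))

# Then, during exported_program.__call__(*args, **kwargs):
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
```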
Here, kwargs is treated as an ordinary pytree whose key order is preserved in in_spec.
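The role of in_spec is easier to see in a self-contained toy version of the two pytree calls, written without torch. The `tree_flatten` and `tree_flatten_spec` functions below are simplified stand-ins for `torch.utils._pytree.tree_flatten` and `torch.fx._pytree.tree_flatten_spec`, not their actual implementations. The point is that a dict spec records its key order at export time. Kwargs passed in a different order at call time therefore still flatten into the order gm_aot_export expects.

```python
def tree_flatten(tree):
    """Toy pytree flatten returning (leaves, spec); dict specs record key order."""
    if isinstance(tree, dict):
        kind, keys, children = "dict", list(tree), list(tree.values())
    elif isinstance(tree, (tuple, list)):
        kind, keys, children = type(tree).__name__, None, list(tree)
    else:
        return [tree], ("leaf", None, [])
    leaves, child_specs = [], []
    for child in children:
        child_leaves, child_spec = tree_flatten(child)
        leaves.extend(child_leaves)
        child_specs.append(child_spec)
    return leaves, (kind, keys, child_specs)


def tree_flatten_spec(tree, spec):
    """Toy flatten driven by an existing spec: dict values are read in the
    spec's recorded key order, not in the caller's insertion order."""
    kind, keys, child_specs = spec
    if kind == "leaf":
        return [tree]
    children = [tree[key] for key in keys] if kind == "dict" else list(tree)
    leaves = []
    for child, child_spec in zip(children, child_specs):
        leaves.extend(tree_flatten_spec(child, child_spec))
    return leaves


# At export time: record the spec of (args, kwargs).
args = (1, 2)
kwargs = {"kw2": 3, "kw1": 4}
_, in_spec = tree_flatten((args, kwargs))

# At call time the caller passes kwargs in a different order.
call_kwargs = {"kw1": 40, "kw2": 30}
flat_args = tree_flatten_spec((args, call_kwargs), in_spec)
print(flat_args)  # [1, 2, 30, 40]
```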

Implementation-wise, we treat `_orig_args` in the Dynamo-exported graph module as the single source of truth, and kwargs are ordered to follow it.
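A minimal sketch of "ordering kwargs to follow `_orig_args`" looks like this. It assumes `_orig_args` can be read as a list of parameter names. `reorder_kwargs` is a hypothetical helper, not the function used in the PR. It rebuilds the dict so that its insertion order, and hence any spec flattened from it, matches `_orig_args`.

```python
def reorder_kwargs(kwargs, orig_args):
    """Hypothetical: rebuild kwargs so its key order follows orig_args.

    Assumes orig_args is the list of parameter names recorded by Dynamo and
    that every key in kwargs appears in it.
    """
    unknown = set(kwargs) - set(orig_args)
    if unknown:
        raise TypeError(f"unexpected keyword arguments: {sorted(unknown)}")
    return {name: kwargs[name] for name in orig_args if name in kwargs}


orig_args = ["arg1", "arg2", "kw2", "kw1"]
print(reorder_kwargs({"kw1": 4, "kw2": 3}, orig_args))  # {'kw2': 3, 'kw1': 4}
```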

Test plan:
See added tests in test_export.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
2023-07-20 19:53:08 +00:00

176 lines
4.4 KiB
Python

import inspect
import os
import re
from pathlib import Path

import torch
import torch._dynamo as torchdynamo

from torch._export import export
from torch._export.db.case import ExportCase, normalize_inputs
from torch._export.db.examples import all_examples


PWD = Path(__file__).absolute().parent
ROOT = Path(__file__).absolute().parent.parent.parent.parent
SOURCE = ROOT / Path("source")
EXPORTDB_SOURCE = SOURCE / Path("generated") / Path("exportdb")


def generate_example_rst(example_case: ExportCase):
    """
    Generates the .rst contents for a single example in db/examples/
    """
    model = example_case.model

    tags = ", ".join(f":doc:`{tag} <{tag}>`" for tag in example_case.tags)

    source_file = (
        inspect.getfile(model.__class__)
        if isinstance(model, torch.nn.Module)
        else inspect.getfile(model)
    )
    with open(source_file) as file:
        source_code = file.read()
    source_code = re.sub(r"from torch\._export\.db\.case import .*\n", "", source_code)
    source_code = re.sub(r"@export_case\((.|\n)*?\)\n", "", source_code)
    source_code = source_code.replace("\n", "\n    ")
    splitted_source_code = re.split(r"@export_rewrite_case.*\n", source_code)

    assert len(splitted_source_code) in {
        1,
        2,
    }, f"more than one @export_rewrite_case decorator in {source_code}"

    # Generate contents of the .rst file
    title = f"{example_case.name}"
    doc_contents = f"""{title}
{'^' * (len(title))}

.. note::

    Tags: {tags}

    Support Level: {example_case.support_level.name}

Original source code:

.. code-block:: python

    {splitted_source_code[0]}

Result:

.. code-block::

"""

    # Get resulting graph from dynamo trace
    try:
        inputs = normalize_inputs(example_case.example_inputs)
        exported_program = export(
            model,
            inputs.args,
            inputs.kwargs,
            constraints=example_case.constraints,
        )
        graph_output = str(exported_program)
        graph_output = re.sub(r"        # File(.|\n)*?\n", "", graph_output)
        graph_output = graph_output.replace("\n", "\n    ")
        output = f"    {graph_output}"
    except torchdynamo.exc.Unsupported as e:
        output = "    Unsupported: " + str(e).split("\n")[0]

    doc_contents += output + "\n"

    if len(splitted_source_code) == 2:
        doc_contents += f"""\n
You can rewrite the example above to something like the following:

.. code-block:: python

{splitted_source_code[1]}

"""

    return doc_contents


def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules):
    """
    Generates the index.rst file
    """
    support_contents = ""
    for k, v in support_level_to_modules.items():
        support_level = k.name.lower().replace("_", " ").title()
        module_contents = "\n\n".join(v)
        support_contents += f"""
{support_level}
{'-' * (len(support_level))}

{module_contents}
"""

    tag_names = "\n    ".join(t for t in tag_to_modules.keys())

    with open(os.path.join(PWD, "blurb.txt")) as file:
        blurb = file.read()

    # Generate contents of the .rst file
    doc_contents = f"""ExportDB
========

{blurb}

.. toctree::
    :maxdepth: 1
    :caption: Tags

    {tag_names}

{support_contents}
"""

    with open(os.path.join(EXPORTDB_SOURCE, "index.rst"), "w") as f:
        f.write(doc_contents)
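

def generate_tag_rst(tag_to_modules):
    """
    For each tag that appears in any ExportCase.tags, generate an .rst file
    containing all the examples that have that tag.
    """
    for tag, modules_rst in tag_to_modules.items():
        doc_contents = f"{tag}\n{'=' * (len(tag) + 4)}\n"
        full_modules_rst = "\n\n".join(modules_rst)
        # Demote "=" heading underlines to "-" without touching "=" inside code:
        # a blanket str.replace("=", "-") would also turn `x = y` into `x - y`.
        full_modules_rst = re.sub(
            r"={3,}", lambda match: "-" * len(match.group()), full_modules_rst
        )
        doc_contents += full_modules_rst

        with open(os.path.join(EXPORTDB_SOURCE, f"{tag}.rst"), "w") as f:
            f.write(doc_contents)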


def generate_rst():
    if not os.path.exists(EXPORTDB_SOURCE):
        os.makedirs(EXPORTDB_SOURCE)

    example_cases = all_examples()
    tag_to_modules = {}
    support_level_to_modules = {}
    for example_case in example_cases.values():
        doc_contents = generate_example_rst(example_case)

        for tag in example_case.tags:
            tag_to_modules.setdefault(tag, []).append(doc_contents)

        support_level_to_modules.setdefault(example_case.support_level, []).append(
            doc_contents
        )

    generate_tag_rst(tag_to_modules)
    generate_index_rst(example_cases, tag_to_modules, support_level_to_modules)


if __name__ == "__main__":
    generate_rst()
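The source-code preprocessing at the top of `generate_example_rst` can be tried in isolation. The sketch below applies the same regular expressions to a made-up example file. `preprocess_example_source` and `SAMPLE_SOURCE` are hypothetical and do not exist in the ExportDB code. The indent width is a parameter here, four spaces by default.

```python
import re

# Made-up example file in the shape the regexes expect; not a real ExportDB case.
SAMPLE_SOURCE = (
    "from torch._export.db.case import export_case, export_rewrite_case\n"
    "import torch\n"
    "\n"
    "@export_case(\n"
    "    example_inputs=(torch.ones(3),),\n"
    ")\n"
    "def original(x):\n"
    "    return x + 1\n"
    "\n"
    "@export_rewrite_case()\n"
    "def rewritten(x):\n"
    "    return x.add(1)\n"
)


def preprocess_example_source(source_code, indent="    "):
    """Hypothetical helper applying the same regexes as generate_example_rst."""
    # Drop the import of the export_case decorators.
    source_code = re.sub(r"from torch\._export\.db\.case import .*\n", "", source_code)
    # Drop the whole (possibly multi-line) @export_case(...) decorator.
    source_code = re.sub(r"@export_case\((.|\n)*?\)\n", "", source_code)
    # Indent continuation lines so they stay inside the rst code-block.
    source_code = source_code.replace("\n", "\n" + indent)
    # Split the original example from its optional rewritten version.
    parts = re.split(r"@export_rewrite_case.*\n", source_code)
    assert len(parts) in {1, 2}, "more than one @export_rewrite_case decorator"
    return parts


original, rewritten = preprocess_example_source(SAMPLE_SOURCE)
print(original)
print(rewritten)
```

Because the split consumes the `@export_rewrite_case` line, including its newline, the second part already starts with the indent. That is consistent with `generate_example_rst` placing `splitted_source_code[1]` at column 0 in its template.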