mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Solving #105242. During export, the exported function's signature changes multiple times. Suppose we'd like to export f as shown in the following example: ```python def f(arg1, arg2, kw1, kw2): pass args = (arg1, arg2) kwargs = {"kw2":arg3, "kw1":arg4} torch.export(f, args, kwargs) ``` The signature changes multiple times during the export process, in the following order: 1. **gm_torch_level = dynamo.export(f, *args, \*\*kwargs)**. In this step, we turn all kinds of parameters such as **positional_only**, **var_positional**, **kw_only**, and **var_kwargs** into **positional_or_kw**. It also preserves the positional and keyword argument names of the original function (i.e. f in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of kwargs will be the **key order** of kwargs (after Python 3.6, this is the insertion order of the keys) instead of the order in the original function signature, and the order is baked into an _orig_args variable of gm_torch_level's pytree info. So we'll have: ```python def gm_torch_level(arg1, arg2, kw2, kw1) ``` Such a difference is acceptable as it's transparent to users of export. 2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we need to turn kwargs into positional args in the order gm_torch_level expects, which is stored in _orig_args. The returned gm_aot_export has the graph signature of flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args): ``` python flat_args, _ = pytree.tree_flatten(pos_or_kw_args) def gm_aot_export(*flat_args) ``` 3. **exported_program(*args, \*\*kwargs)**. The exported artifact is exported_program, which is a wrapper over gm_aot_export and has the same calling convention as the original function "f". To do this, we need to 1. specialize the order of kwargs into pos_or_kw_args and 2. flatten the pos_or_kw_args into what gm_aot_export expects. 
We can combine the two steps into one with: ```python _, in_spec = pytree.tree_flatten((args, kwargs)) # Then during exported_program.__call__(*args, **kwargs) flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec) ``` , where kwargs is treated as a normal pytree whose key order is preserved in in_spec. Implementation-wise, we treat _orig_args in the dynamo exported graph module as the single source of truth, and kwargs are ordered following it. Test plan: See added tests in test_export.py. Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337 Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
176 lines
4.4 KiB
Python
176 lines
4.4 KiB
Python
import inspect
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch._dynamo as torchdynamo
|
|
from torch._export import export
|
|
|
|
from torch._export.db.case import ExportCase, normalize_inputs
|
|
from torch._export.db.examples import all_examples
|
|
|
|
|
|
# Absolute path of the directory that contains this script.
PWD = Path(__file__).absolute().parent
# Three directory levels above the script's own directory.
ROOT = PWD.parent.parent.parent
# Documentation source tree, and the sub-directory the generated ExportDB
# .rst files are written to.
SOURCE = ROOT / "source"
EXPORTDB_SOURCE = SOURCE / "generated" / "exportdb"
|
|
|
|
|
|
def generate_example_rst(example_case: ExportCase):
    """
    Build the reStructuredText section that documents one ExportDB example.

    The section contains the example's title, its tags and support level,
    the example's source code, the printed result of exporting it (or the
    first line of the "Unsupported" error), and, when the source file
    contains an ``@export_rewrite_case`` decorator, the suggested rewrite.

    Args:
        example_case: The case to document. Its ``model``, ``tags``, ``name``,
            ``support_level``, ``example_inputs`` and ``constraints``
            attributes are read.

    Returns:
        The RST text for this example as a single string. Nothing is written
        to disk here.

    Raises:
        AssertionError: If the source file contains more than one
            ``@export_rewrite_case`` decorator.
    """

    model = example_case.model

    tags = ", ".join(f":doc:`{tag} <{tag}>`" for tag in example_case.tags)

    # For an nn.Module instance the defining file is found via its class;
    # anything else is handed to inspect.getfile directly.
    source_file = (
        inspect.getfile(model.__class__)
        if isinstance(model, torch.nn.Module)
        else inspect.getfile(model)
    )
    # Read as UTF-8 explicitly so the result does not depend on the locale.
    with open(source_file, encoding="utf-8") as file:
        source_code = file.read()

    # Drop the ExportDB bookkeeping (case import and @export_case(...) call)
    # so only the example itself is shown.
    source_code = re.sub(r"from torch\._export\.db\.case import .*\n", "", source_code)
    source_code = re.sub(r"@export_case\((.|\n)*?\)\n", "", source_code)
    # Indent every continuation line by four spaces so the whole snippet sits
    # inside the ``.. code-block::`` directive; the first line gets its indent
    # from the template below.
    source_code = source_code.replace("\n", "\n    ")
    # Everything after an @export_rewrite_case line is the suggested rewrite.
    # Because the newline before that text was already replaced above, the
    # second piece starts with its four-space indent in place.
    splitted_source_code = re.split(r"@export_rewrite_case.*\n", source_code)

    assert len(splitted_source_code) in {
        1,
        2,
    }, f"more than one @export_rewrite_case decorator in {source_code}"

    # Generate contents of the .rst file
    title = f"{example_case.name}"
    doc_contents = f"""{title}
{'^' * (len(title))}

.. note::

    Tags: {tags}

    Support Level: {example_case.support_level.name}

Original source code:

.. code-block:: python

    {splitted_source_code[0]}

Result:

.. code-block::

"""

    # Get resulting graph from dynamo trace
    try:
        inputs = normalize_inputs(example_case.example_inputs)
        exported_program = export(
            model,
            inputs.args,
            inputs.kwargs,
            constraints=example_case.constraints,
        )
        graph_output = str(exported_program)
        # Remove whole "# File ..." comment lines, including their leading
        # indentation, so no dangling whitespace is left in front of the
        # following line.
        graph_output = re.sub(
            r"^[ \t]*# File.*\n", "", graph_output, flags=re.MULTILINE
        )
        graph_output = graph_output.replace("\n", "\n    ")
        output = f"    {graph_output}"
    except torchdynamo.exc.Unsupported as e:
        # Unsupported examples are documented with the first line of the error.
        output = "    Unsupported: " + str(e).split("\n")[0]

    doc_contents += output + "\n"

    if len(splitted_source_code) == 2:
        # The rewrite text is inserted at column 0 on purpose: it already
        # carries the four-space indent (see the split above).
        doc_contents += f"""\n
You can rewrite the example above to something like the following:

.. code-block:: python

{splitted_source_code[1]}

"""

    return doc_contents
|
|
|
|
|
|
def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules):
    """
    Write ``index.rst``, the landing page of the generated ExportDB docs.

    The page consists of the contents of ``blurb.txt`` (read from the
    directory of this script), a toctree with one entry per tag, and one
    section per support level containing the RST of its examples.

    Args:
        example_cases: Not read by this function; kept in the signature so
            existing callers keep working.
        tag_to_modules: Mapping whose keys are the tag names; each key becomes
            one toctree entry.
        support_level_to_modules: Mapping from a support level (an object with
            a ``name`` attribute) to the list of RST strings of the examples
            at that level.

    Returns:
        None. The page is written to ``EXPORTDB_SOURCE / "index.rst"``.
    """

    support_sections = []
    for level, modules_rst in support_level_to_modules.items():
        # e.g. a level named "NOT_SUPPORTED_YET" is shown as "Not Supported Yet".
        support_level = level.name.lower().replace("_", " ").title()
        module_contents = "\n\n".join(modules_rst)
        support_sections.append(
            f"""
{support_level}
{'-' * (len(support_level))}

{module_contents}
"""
        )
    support_contents = "".join(support_sections)

    # The separator's indentation must match the indentation of {tag_names}
    # in the template below so every entry lines up under the toctree.
    tag_names = "\n    ".join(tag_to_modules)

    with open(os.path.join(PWD, "blurb.txt"), encoding="utf-8") as file:
        blurb = file.read()

    # Generate contents of the .rst file
    doc_contents = f"""ExportDB
========

{blurb}

.. toctree::
    :maxdepth: 1
    :caption: Tags

    {tag_names}

{support_contents}
"""

    with open(
        os.path.join(EXPORTDB_SOURCE, "index.rst"), "w", encoding="utf-8"
    ) as f:
        f.write(doc_contents)
|
|
|
|
|
|
def generate_tag_rst(tag_to_modules, output_dir=None):
    """
    For each tag that shows up in each ExportCase.tag, generate an .rst file
    containing all the examples that have that tag.

    Args:
        tag_to_modules: Mapping from tag name to the list of RST strings of
            the examples carrying that tag.
        output_dir: Directory the ``<tag>.rst`` files are written to. Defaults
            to ``EXPORTDB_SOURCE`` when omitted or ``None``.

    Returns:
        None. One file per tag is written.
    """

    # Resolved at call time so the default always follows EXPORTDB_SOURCE.
    if output_dir is None:
        output_dir = EXPORTDB_SOURCE

    for tag, modules_rst in tag_to_modules.items():
        doc_contents = f"{tag}\n{'=' * (len(tag) + 4)}\n"

        full_modules_rst = "\n\n".join(modules_rst)
        # Demote "=" heading adornments to "-" so they rank below the tag
        # title. Only lines made up entirely of "=" are touched; a blanket
        # str.replace would also rewrite every "=" inside the embedded
        # example code (turning `x = 1` into `x - 1`).
        full_modules_rst = re.sub(
            r"^=+$",
            lambda match: "-" * len(match.group()),
            full_modules_rst,
            flags=re.MULTILINE,
        )
        doc_contents += full_modules_rst

        with open(
            os.path.join(output_dir, f"{tag}.rst"), "w", encoding="utf-8"
        ) as f:
            f.write(doc_contents)
|
|
|
|
|
|
def generate_rst():
    """
    Generate every ExportDB .rst file: one page per tag plus the index page.

    Each example returned by ``all_examples()`` is rendered once, and the
    resulting RST text is grouped both by each of the example's tags and by
    its support level before the tag pages and the index page are written.
    """
    # exist_ok avoids the check-then-create race of testing for the
    # directory first and creating it afterwards.
    os.makedirs(EXPORTDB_SOURCE, exist_ok=True)

    example_cases = all_examples()
    tag_to_modules = {}
    support_level_to_modules = {}
    for example_case in example_cases.values():
        doc_contents = generate_example_rst(example_case)

        # The same rendered text is listed under every tag the example has.
        for tag in example_case.tags:
            tag_to_modules.setdefault(tag, []).append(doc_contents)

        support_level_to_modules.setdefault(
            example_case.support_level, []
        ).append(doc_contents)

    generate_tag_rst(tag_to_modules)
    generate_index_rst(example_cases, tag_to_modules, support_level_to_modules)
|
|
|
|
|
|
if __name__ == "__main__":
    # Regenerate the ExportDB .rst files only when this file is run as a
    # script, not when it is imported.
    generate_rst()
|