[export] Update example inputs format for DB. (#129982)

Summary: To give user a simpler example code, we are getting rid of ExportArgs in favor of example_args and example_kwargs.

Test Plan: CI

Differential Revision: D59288920

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129982
Approved by: https://github.com/angelayi
This commit is contained in:
Zhengxu Chen
2024-07-03 17:53:15 +00:00
committed by PyTorch MergeBot
parent 9b902b3ee3
commit 042d764872
40 changed files with 96 additions and 96 deletions

View File

@ -6,7 +6,7 @@ from pathlib import Path
import torch
import torch._dynamo as torchdynamo
from torch._export.db.case import ExportCase, normalize_inputs
from torch._export.db.case import ExportCase
from torch._export.db.examples import all_examples
from torch.export import export
@ -41,8 +41,13 @@ def generate_example_rst(example_case: ExportCase):
2,
}, f"more than one @export_rewrite_case decorator in {source_code}"
more_arguments = ""
if example_case.example_kwargs:
more_arguments += ", example_kwargs"
if example_case.dynamic_shapes:
more_arguments += ", dynamic_shapes=dynamic_shapes"
# Generate contents of the .rst file
# TODO(zhxchen17) Update template when we switch to example_args and example_kwargs.
title = f"{example_case.name}"
doc_contents = f"""{title}
{'^' * (len(title))}
@ -59,6 +64,8 @@ Original source code:
{splitted_source_code[0]}
torch.export.export(model, example_args{more_arguments})
Result:
.. code-block::
@ -67,11 +74,10 @@ Result:
# Get resulting graph from dynamo trace
try:
inputs = normalize_inputs(example_case.example_inputs)
exported_program = export(
model,
inputs.args,
inputs.kwargs,
example_case.example_args,
example_case.example_kwargs,
dynamic_shapes=example_case.dynamic_shapes,
)
graph_output = str(exported_program)