mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[export] Update example inputs format for DB. (#129982)
Summary: To give user a simpler example code, we are getting rid of ExportArgs in favor of example_args and example_kwargs. Test Plan: CI Differential Revision: D59288920 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129982 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
9b902b3ee3
commit
042d764872
@ -6,7 +6,7 @@ from pathlib import Path
|
||||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
|
||||
from torch._export.db.case import ExportCase, normalize_inputs
|
||||
from torch._export.db.case import ExportCase
|
||||
from torch._export.db.examples import all_examples
|
||||
from torch.export import export
|
||||
|
||||
@ -41,8 +41,13 @@ def generate_example_rst(example_case: ExportCase):
|
||||
2,
|
||||
}, f"more than one @export_rewrite_case decorator in {source_code}"
|
||||
|
||||
more_arguments = ""
|
||||
if example_case.example_kwargs:
|
||||
more_arguments += ", example_kwargs"
|
||||
if example_case.dynamic_shapes:
|
||||
more_arguments += ", dynamic_shapes=dynamic_shapes"
|
||||
|
||||
# Generate contents of the .rst file
|
||||
# TODO(zhxchen17) Update template when we switch to example_args and example_kwargs.
|
||||
title = f"{example_case.name}"
|
||||
doc_contents = f"""{title}
|
||||
{'^' * (len(title))}
|
||||
@ -59,6 +64,8 @@ Original source code:
|
||||
|
||||
{splitted_source_code[0]}
|
||||
|
||||
torch.export.export(model, example_args{more_arguments})
|
||||
|
||||
Result:
|
||||
|
||||
.. code-block::
|
||||
@ -67,11 +74,10 @@ Result:
|
||||
|
||||
# Get resulting graph from dynamo trace
|
||||
try:
|
||||
inputs = normalize_inputs(example_case.example_inputs)
|
||||
exported_program = export(
|
||||
model,
|
||||
inputs.args,
|
||||
inputs.kwargs,
|
||||
example_case.example_args,
|
||||
example_case.example_kwargs,
|
||||
dynamic_shapes=example_case.dynamic_shapes,
|
||||
)
|
||||
graph_output = str(exported_program)
|
||||
|
Reference in New Issue
Block a user