[export] update docs to not export raw functions (#121272)

as title

Differential Revision: [D54555101](https://our.internmc.facebook.com/intern/diff/D54555101/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121272
Approved by: https://github.com/zhxchen17
This commit is contained in:
suo
2024-03-05 13:37:45 -08:00
committed by PyTorch MergeBot
parent 862b99b571
commit c3c15eb9a6

View File

@ -22,15 +22,16 @@ serialized.
import torch
from torch.export import export
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
f, args=example_args
Mod(), args=example_args
)
print(exported_program)
@ -393,14 +394,15 @@ branch that is being taken with the given sample inputs. For example:
import torch
from torch.export import export
def fn(x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
class Mod(torch.nn.Module):
def forward(self, x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(fn, example_inputs)
exported_program = export(Mod(), example_inputs)
print(exported_program)
.. code-block::
@ -435,13 +437,14 @@ For example:
import torch
from torch.export import export
def fn(x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(fn, example_inputs)
exported_program = export(Mod(), example_inputs)
print(exported_program)
.. code-block::