mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] update docs to not export raw functions (#121272)
as title Differential Revision: [D54555101](https://our.internmc.facebook.com/intern/diff/D54555101/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/121272 Approved by: https://github.com/zhxchen17
This commit is contained in:
@ -22,15 +22,16 @@ serialized.
|
||||
import torch
|
||||
from torch.export import export
|
||||
|
||||
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
a = torch.sin(x)
|
||||
b = torch.cos(y)
|
||||
return a + b
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
a = torch.sin(x)
|
||||
b = torch.cos(y)
|
||||
return a + b
|
||||
|
||||
example_args = (torch.randn(10, 10), torch.randn(10, 10))
|
||||
|
||||
exported_program: torch.export.ExportedProgram = export(
|
||||
f, args=example_args
|
||||
Mod(), args=example_args
|
||||
)
|
||||
print(exported_program)
|
||||
|
||||
@ -393,14 +394,15 @@ branch that is being taken with the given sample inputs. For example:
|
||||
import torch
|
||||
from torch.export import export
|
||||
|
||||
def fn(x):
|
||||
if x.shape[0] > 5:
|
||||
return x + 1
|
||||
else:
|
||||
return x - 1
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
if x.shape[0] > 5:
|
||||
return x + 1
|
||||
else:
|
||||
return x - 1
|
||||
|
||||
example_inputs = (torch.rand(10, 2),)
|
||||
exported_program = export(fn, example_inputs)
|
||||
exported_program = export(Mod(), example_inputs)
|
||||
print(exported_program)
|
||||
|
||||
.. code-block::
|
||||
@ -435,13 +437,14 @@ For example:
|
||||
import torch
|
||||
from torch.export import export
|
||||
|
||||
def fn(x: torch.Tensor, const: int, times: int):
|
||||
for i in range(times):
|
||||
x = x + const
|
||||
return x
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, const: int, times: int):
|
||||
for i in range(times):
|
||||
x = x + const
|
||||
return x
|
||||
|
||||
example_inputs = (torch.rand(2, 2), 1, 3)
|
||||
exported_program = export(fn, example_inputs)
|
||||
exported_program = export(Mod(), example_inputs)
|
||||
print(exported_program)
|
||||
|
||||
.. code-block::
|
||||
|
Reference in New Issue
Block a user