mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[export] Update docs (#142011)
Summary: Update export docs. Including: 1. Update the output graph. 2. Misc fixes for examples. Test Plan: CI Differential Revision: D66726729 Pull Request resolved: https://github.com/pytorch/pytorch/pull/142011 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
471017cbc9
commit
31f2d4eb4e
@ -589,22 +589,27 @@ def register_dataclass(
|
||||
|
||||
Example::
|
||||
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class InputDataClass:
|
||||
feature: torch.Tensor
|
||||
bias: int
|
||||
|
||||
@dataclass
|
||||
class OutputDataClass:
|
||||
res: torch.Tensor
|
||||
|
||||
torch.export.register_dataclass(InputDataClass)
|
||||
torch.export.register_dataclass(OutputDataClass)
|
||||
|
||||
def fn(o: InputDataClass) -> torch.Tensor:
|
||||
res = res=o.feature + o.bias
|
||||
return OutputDataClass(res=res)
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x: InputDataClass) -> OutputDataClass:
|
||||
res = x.feature + x.bias
|
||||
return OutputDataClass(res=res)
|
||||
|
||||
ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
|
||||
ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))
|
||||
print(ep)
|
||||
|
||||
"""
|
||||
|
Reference in New Issue
Block a user