mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE][PYFMT] migrate PYFMT for torch/[e-n]*/
to ruff format
(#144553)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144553 Approved by: https://github.com/ezyang ghstack dependencies: #144551
This commit is contained in:
committed by
PyTorch MergeBot
parent
95cb42c45d
commit
2e0e08588e
@ -358,22 +358,24 @@ def save(
|
||||
import torch
|
||||
import io
|
||||
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 10
|
||||
|
||||
|
||||
ep = torch.export.export(MyModule(), (torch.randn(5),))
|
||||
|
||||
# Save to file
|
||||
torch.export.save(ep, 'exported_program.pt2')
|
||||
torch.export.save(ep, "exported_program.pt2")
|
||||
|
||||
# Save to io.BytesIO buffer
|
||||
buffer = io.BytesIO()
|
||||
torch.export.save(ep, buffer)
|
||||
|
||||
# Save with extra files
|
||||
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
|
||||
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
|
||||
extra_files = {"foo.txt": b"bar".decode("utf-8")}
|
||||
torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
|
||||
|
||||
"""
|
||||
if not isinstance(ep, ExportedProgram):
|
||||
@ -427,18 +429,18 @@ def load(
|
||||
import io
|
||||
|
||||
# Load ExportedProgram from file
|
||||
ep = torch.export.load('exported_program.pt2')
|
||||
ep = torch.export.load("exported_program.pt2")
|
||||
|
||||
# Load ExportedProgram from io.BytesIO object
|
||||
with open('exported_program.pt2', 'rb') as f:
|
||||
with open("exported_program.pt2", "rb") as f:
|
||||
buffer = io.BytesIO(f.read())
|
||||
buffer.seek(0)
|
||||
ep = torch.export.load(buffer)
|
||||
|
||||
# Load with extra files.
|
||||
extra_files = {'foo.txt': ''} # values will be replaced with data
|
||||
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
|
||||
print(extra_files['foo.txt'])
|
||||
extra_files = {"foo.txt": ""} # values will be replaced with data
|
||||
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
|
||||
print(extra_files["foo.txt"])
|
||||
print(ep(torch.randn(5)))
|
||||
"""
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
@ -572,24 +574,29 @@ def register_dataclass(
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputDataClass:
|
||||
feature: torch.Tensor
|
||||
bias: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputDataClass:
|
||||
res: torch.Tensor
|
||||
|
||||
|
||||
torch.export.register_dataclass(InputDataClass)
|
||||
torch.export.register_dataclass(OutputDataClass)
|
||||
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x: InputDataClass) -> OutputDataClass:
|
||||
res = x.feature + x.bias
|
||||
return OutputDataClass(res=res)
|
||||
|
||||
ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))
|
||||
|
||||
ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
|
||||
print(ep)
|
||||
|
||||
"""
|
||||
|
Reference in New Issue
Block a user