mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Summary: Provide a standalone path to compile and run an ExportedProgram in C++. Test Plan: (1) Generate a compiled model from an ExportedProgram ``` python generate_lowered_cpu.py --input-path /tmp/$USER/ep.pt --output-path /tmp/$USER/final.pt ``` (2) Compile a standalone test runner ``` TORCH_ROOT_DIR=/data/users/$USER/pytorch sh standalone_compile.sh standalone_test.cpp standalone_test.out ``` (3) Run the test for the compiled model from step (1) ``` LD_LIBRARY_PATH=/data/users/$USER/pytorch/build/lib ./standalone_test.out /tmp/$USER/final.pt ``` Differential Revision: D66872380 Pull Request resolved: https://github.com/pytorch/pytorch/pull/142327 Approved by: https://github.com/hl475
63 lines
1.4 KiB
Python
63 lines
1.4 KiB
Python
import copy
|
|
|
|
import click
|
|
|
|
import torch
|
|
|
|
|
|
class Serializer(torch.nn.Module):
    """Expose every entry of a mapping as an attribute of an ``nn.Module``.

    For each key in ``data`` the corresponding value is assigned to
    ``self.<key>``. Module-typed values therefore become registered
    submodules, and everything else becomes a plain attribute. This lets the
    whole bundle be passed through ``torch.jit.script`` and saved as one
    archive.

    Args:
        data: an object that can be iterated for keys and indexed with those
            keys (e.g. a ``dict``); keys are used verbatim as attribute names.
    """

    def __init__(self, data):
        super().__init__()
        for attr_name in data:
            attr_value = data[attr_name]
            # setattr (rather than __dict__) so nn.Module can register
            # submodules/parameters/buffers appropriately.
            setattr(self, attr_name, attr_value)
|
|
|
|
|
|
@click.command()
@click.option(
    "--input-path",
    type=str,
    # No default on purpose: click treats any non-None default (even "") as
    # "provided", which would silently disable the required check.
    required=True,
    help="path to the ExportedProgram",
)
@click.option(
    "--output-path",
    type=str,
    required=True,
    help="path where the scripted bundle (traced module, compiled .so path, "
    "inputs and reference outputs) is saved",
)
def main(
    input_path: str = "",
    output_path: str = "",
) -> None:
    """Lower an ExportedProgram with AOTInductor and bundle it for a test runner.

    Loads the ExportedProgram at ``input_path`` and, using its stored
    positional example inputs, builds a bundle with these entries:

    * ``script_module``: a ``torch.jit.trace`` of the original module.
    * ``model_so_path``: the path returned by ``torch._inductor.aot_compile``.
    * ``inputs``: the example inputs, as a list.
    * ``outputs``: reference outputs from running the module's FX graph with
      ``torch.fx.Interpreter``, always as a list.

    The bundle's entries are attached as attributes of a :class:`Serializer`
    module, which is scripted and saved to ``output_path``.

    Raises:
        click.BadParameter: if the loaded ExportedProgram carries no example
            inputs, which are needed to trace, compile and run the model.
    """
    ep = torch.export.load(input_path)
    if ep.example_inputs is None:
        raise click.BadParameter(
            f"the ExportedProgram at {input_path!r} has no example inputs; "
            "re-export it with example inputs before lowering",
            param_hint="'--input-path'",
        )

    with torch.no_grad():
        # ep.example_inputs is an (args, kwargs) pair; only the positional
        # args are used to trace, compile and run the model.
        example_inputs = ep.example_inputs[0]

        # Get traced TorchScript version of the original module.
        module = torch.jit.trace(copy.deepcopy(ep.module()), example_inputs)

        # Get AOT-compiled module (path to the generated shared library).
        so_path = torch._inductor.aot_compile(ep.module(), example_inputs)

        # Reference outputs: interpret the FX graph. Interpreter.run takes the
        # inputs as varargs, so unpack them to feed one per placeholder.
        runner = torch.fx.Interpreter(ep.module())
        output = runner.run(*example_inputs)
        # Normalize to a list so the runner always sees a sequence of outputs.
        output = list(output) if isinstance(output, (list, tuple)) else [output]

    data = {
        "script_module": module,
        "model_so_path": so_path,
        "inputs": list(example_inputs),
        "outputs": output,
    }
    torch.jit.script(Serializer(data)).save(output_path)


if __name__ == "__main__":
    main()
|