mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352 Approved by: https://github.com/ezyang ghstack dependencies: #132335, #132351
45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
import argparse
|
|
import os.path
|
|
import sys
|
|
|
|
import torch
|
|
|
|
|
|
def get_custom_op_library_path():
    """Return the absolute path of the compiled custom-ops shared library.

    The file name is chosen from ``sys.platform``: ``custom_ops.dll`` on
    Windows, ``libcustom_ops.dylib`` on macOS, and ``libcustom_ops.so``
    everywhere else. The file is looked up under ``build/`` relative to the
    current working directory.

    Returns:
        The absolute path of the library file.

    Raises:
        AssertionError: If nothing exists at the computed path. The path is
            passed as the exception argument.
    """
    if sys.platform.startswith("win32"):
        library_filename = "custom_ops.dll"
    elif sys.platform.startswith("darwin"):
        library_filename = "libcustom_ops.dylib"
    else:
        library_filename = "libcustom_ops.so"
    path = os.path.abspath(f"build/{library_filename}")
    # Raise explicitly instead of using an `assert` statement: asserts are
    # removed when Python runs with -O, which would let a missing library
    # slip through this check. The exception type is kept as AssertionError
    # so existing callers observe the same failure.
    if not os.path.exists(path):
        raise AssertionError(path)
    return path
|
|
|
|
|
|
class Model(torch.jit.ScriptModule):
    """Script module whose forward pass calls ``torch.ops.custom.op_with_defaults``.

    The module holds a single parameter ``p``, initialised to a 5x5 identity
    matrix. ``forward`` does not read ``p``.
    """

    def __init__(self) -> None:
        super().__init__()
        # 5x5 identity matrix registered as a parameter of the module.
        self.p = torch.nn.Parameter(torch.eye(5))

    @torch.jit.script_method
    def forward(self, input):
        # Call the custom op with only `input`, take element 0 of what it
        # returns, and add 1 to it.
        # NOTE(review): presumably the op's remaining arguments fall back to
        # defaults declared in its schema (going by the op's name) -- confirm
        # against the op registration in the custom-ops library.
        return torch.ops.custom.op_with_defaults(input)[0] + 1
|
|
|
|
|
|
def main():
    """Script entry point: build a ``Model`` and save it to a given path.

    Reads the required ``--export-script-module-to`` command-line option,
    loads the custom-ops library located by ``get_custom_op_library_path()``
    through ``torch.ops.load_library``, then instantiates ``Model`` and
    calls its ``save`` method with the requested destination.
    """
    cli = argparse.ArgumentParser(
        description="Serialize a script module with custom ops"
    )
    cli.add_argument("--export-script-module-to", required=True)
    destination = cli.parse_args().export_script_module_to

    # The library is loaded before Model is instantiated -- presumably so
    # that torch.ops.custom.op_with_defaults can be resolved when the
    # module's forward method is compiled.
    library_path = get_custom_op_library_path()
    torch.ops.load_library(library_path)

    Model().save(destination)
|
|
|
|
|
|
# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|