mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129764
Approved by: https://github.com/ezyang
92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
import argparse
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
|
|
|
|
# grab modules from test_jit_hooks.cpp
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from jit.test_hooks_modules import (
|
|
create_forward_tuple_input,
|
|
create_module_forward_multiple_inputs,
|
|
create_module_forward_single_input,
|
|
create_module_hook_return_nothing,
|
|
create_module_multiple_hooks_multiple_inputs,
|
|
create_module_multiple_hooks_single_input,
|
|
create_module_no_forward_input,
|
|
create_module_same_hook_repeated,
|
|
create_submodule_forward_multiple_inputs,
|
|
create_submodule_forward_single_input,
|
|
create_submodule_hook_return_nothing,
|
|
create_submodule_multiple_hooks_multiple_inputs,
|
|
create_submodule_multiple_hooks_single_input,
|
|
create_submodule_same_hook_repeated,
|
|
create_submodule_to_call_directly_with_hooks,
|
|
)
|
|
|
|
|
|
# Create saved modules for JIT forward hooks and pre-hooks
|
|
def main():
    """Script every hook test module and save each one to disk.

    Reads the required ``--export-script-module-to`` command-line option and
    uses its value, followed by an underscore, as the prefix of every output
    path. Each module is compiled with ``torch.jit.script`` and written with
    ``torch.jit.save`` to ``<prefix><test name>.pt``. A confirmation line is
    printed once all modules have been saved.
    """
    arg_parser = argparse.ArgumentParser(
        description="Serialize a script modules with hooks attached"
    )
    arg_parser.add_argument("--export-script-module-to", required=True)
    cli_args = arg_parser.parse_args()

    # The output prefix is stored in a module-level global rather than a local.
    # NOTE(review): nothing visible in this file reads it back outside of
    # main(); confirm before turning it into a plain local.
    global save_name
    save_name = cli_args.export_script_module_to + "_"

    # (test name, factory) pairs; the test name becomes part of the file name.
    factories = (
        (
            "test_submodule_forward_single_input",
            create_submodule_forward_single_input,
        ),
        (
            "test_submodule_forward_multiple_inputs",
            create_submodule_forward_multiple_inputs,
        ),
        (
            "test_submodule_multiple_hooks_single_input",
            create_submodule_multiple_hooks_single_input,
        ),
        (
            "test_submodule_multiple_hooks_multiple_inputs",
            create_submodule_multiple_hooks_multiple_inputs,
        ),
        (
            "test_submodule_hook_return_nothing",
            create_submodule_hook_return_nothing,
        ),
        (
            "test_submodule_same_hook_repeated",
            create_submodule_same_hook_repeated,
        ),
        (
            "test_module_forward_single_input",
            create_module_forward_single_input,
        ),
        (
            "test_module_forward_multiple_inputs",
            create_module_forward_multiple_inputs,
        ),
        (
            "test_module_multiple_hooks_single_input",
            create_module_multiple_hooks_single_input,
        ),
        (
            "test_module_multiple_hooks_multiple_inputs",
            create_module_multiple_hooks_multiple_inputs,
        ),
        (
            "test_module_hook_return_nothing",
            create_module_hook_return_nothing,
        ),
        (
            "test_module_same_hook_repeated",
            create_module_same_hook_repeated,
        ),
        (
            "test_module_no_forward_input",
            create_module_no_forward_input,
        ),
        (
            "test_forward_tuple_input",
            create_forward_tuple_input,
        ),
        (
            "test_submodule_to_call_directly_with_hooks",
            create_submodule_to_call_directly_with_hooks,
        ),
    )

    # Build every module up front, in the listed order, before any of them is
    # scripted or written out.
    modules = [(label, build()) for label, build in factories]

    for label, module in modules:
        torch.jit.save(torch.jit.script(module), f"{save_name}{label}.pt")

    print("OK: completed saving modules with hooks!")
|
|
|
|
|
|
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|