mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[WIP] JIT Static Hooks: cpp tests (#49547)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49547 Test Plan: Imported from OSS Reviewed By: heitorschueroff Differential Revision: D25771118 Pulled By: Lilyjjo fbshipit-source-id: cd8a58ff008a1c5d65ccbfbcbcb0214781ece16f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3b88e1b0e7
commit
22902b9242
1
.gitignore
vendored
1
.gitignore
vendored
@ -42,6 +42,7 @@ test/.coverage
|
||||
test/.hypothesis/
|
||||
test/cpp/api/mnist
|
||||
test/custom_operator/model.pt
|
||||
test/jit_hooks/*.pt
|
||||
test/data/legacy_modules.t7
|
||||
test/data/*.pt
|
||||
test/backward_compatibility/nightly_schemas.txt
|
||||
|
@ -223,6 +223,18 @@ else
|
||||
popd
|
||||
assert_git_not_dirty
|
||||
|
||||
# Build jit hook tests
|
||||
JIT_HOOK_BUILD="$PWD/../jit-hook-build"
|
||||
JIT_HOOK_TEST="$PWD/test/jit_hooks"
|
||||
python --version
|
||||
SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
|
||||
mkdir "$JIT_HOOK_BUILD"
|
||||
pushd "$JIT_HOOK_BUILD"
|
||||
cmake "$JIT_HOOK_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPYTHON_EXECUTABLE="$(which python)"
|
||||
make VERBOSE=1
|
||||
popd
|
||||
assert_git_not_dirty
|
||||
|
||||
# Build custom backend tests.
|
||||
CUSTOM_BACKEND_BUILD="$PWD/../custom-backend-build"
|
||||
CUSTOM_BACKEND_TEST="$PWD/test/custom_backend"
|
||||
|
@ -134,11 +134,31 @@ test_custom_script_ops() {
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_jit_hooks() {
|
||||
echo "Testing jit hooks in cpp"
|
||||
pushd test/jit_hooks
|
||||
# Build the custom operator library.
|
||||
rm -rf build && mkdir build
|
||||
pushd build
|
||||
SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
|
||||
CMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" cmake ..
|
||||
make VERBOSE=1
|
||||
popd
|
||||
|
||||
# Run tests Python-side and export a script module.
|
||||
python model.py --export-script-module=model
|
||||
# Run tests C++-side and load the exported script module.
|
||||
build/test_jit_hooks ./model
|
||||
popd
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
|
||||
if [ -z "${BUILD_ENVIRONMENT}" ] || [[ "${BUILD_ENVIRONMENT}" == *-test ]]; then
|
||||
test_python_all
|
||||
test_libtorch
|
||||
test_custom_script_ops
|
||||
test_jit_hooks
|
||||
test_custom_backend
|
||||
else
|
||||
if [[ "${BUILD_ENVIRONMENT}" == *-test1 ]]; then
|
||||
@ -146,6 +166,7 @@ else
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *-test2 ]]; then
|
||||
test_libtorch
|
||||
test_custom_script_ops
|
||||
test_jit_hooks
|
||||
test_custom_backend
|
||||
fi
|
||||
fi
|
||||
|
@ -242,6 +242,21 @@ test_custom_script_ops() {
|
||||
fi
|
||||
}
|
||||
|
||||
test_jit_hooks() {
|
||||
if [[ "$BUILD_ENVIRONMENT" != *rocm* ]] && [[ "$BUILD_ENVIRONMENT" != *asan* ]] ; then
|
||||
echo "Testing jit hooks in cpp"
|
||||
HOOK_BUILD="$PWD/../jit-hook-build"
|
||||
pushd test/jit_hooks
|
||||
cp -a "$HOOK_BUILD" build
|
||||
# Run tests Python-side and export the script modules with hooks
|
||||
python model.py --export-script-module=model
|
||||
# Run tests C++-side and load the exported script modules
|
||||
build/test_jit_hooks ./model
|
||||
popd
|
||||
assert_git_not_dirty
|
||||
fi
|
||||
}
|
||||
|
||||
test_torch_function_benchmark() {
|
||||
echo "Testing __torch_function__ benchmarks"
|
||||
pushd benchmarks/overrides_benchmark
|
||||
|
9
test/jit_hooks/CMakeLists.txt
Normal file
9
test/jit_hooks/CMakeLists.txt
Normal file
@ -0,0 +1,9 @@
|
||||
# Basic CMake setup
|
||||
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
|
||||
project(jit_hooks)
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
add_executable(test_jit_hooks test_jit_hooks.cpp)
|
||||
set_property(TARGET test_jit_hooks PROPERTY CXX_STANDARD 14)
|
||||
target_link_libraries(test_jit_hooks "${TORCH_LIBRARIES}")
|
50
test/jit_hooks/model.py
Normal file
50
test/jit_hooks/model.py
Normal file
@ -0,0 +1,50 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
|
||||
# grab modules from test_jit_hooks.cpp
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from jit.test_hooks_modules import *
|
||||
|
||||
# Create saved modules for JIT forward hooks and pre-hooks
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Serialize a script modules with hooks attached"
|
||||
)
|
||||
parser.add_argument("--export-script-module-to", required=True)
|
||||
options = parser.parse_args()
|
||||
global save_name
|
||||
save_name = options.export_script_module_to + "_"
|
||||
|
||||
tests = [
|
||||
("test_submodule_forward_single_input", create_submodule_forward_single_input()),
|
||||
("test_submodule_forward_multiple_inputs", create_submodule_forward_multiple_inputs()),
|
||||
("test_submodule_multiple_hooks_single_input", create_submodule_multiple_hooks_single_input()),
|
||||
("test_submodule_multiple_hooks_multiple_inputs", create_submodule_multiple_hooks_multiple_inputs()),
|
||||
("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
|
||||
("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),
|
||||
|
||||
("test_module_forward_single_input", create_module_forward_single_input()),
|
||||
("test_module_forward_multiple_inputs", create_module_forward_multiple_inputs()),
|
||||
("test_module_multiple_hooks_single_input", create_module_multiple_hooks_single_input()),
|
||||
("test_module_multiple_hooks_multiple_inputs", create_module_multiple_hooks_multiple_inputs()),
|
||||
("test_module_hook_return_nothing", create_module_hook_return_nothing()),
|
||||
("test_module_same_hook_repeated", create_module_same_hook_repeated()),
|
||||
|
||||
("test_module_no_forward_input", create_module_no_forward_input()),
|
||||
("test_forward_tuple_input", create_forward_tuple_input()),
|
||||
("test_submodule_to_call_directly_with_hooks", create_submodule_to_call_directly_with_hooks())
|
||||
]
|
||||
|
||||
for name, model in tests:
|
||||
m_scripted = torch.jit.script(model)
|
||||
filename = save_name + name + ".pt"
|
||||
torch.jit.save(m_scripted, filename)
|
||||
|
||||
print("OK: completed saving modules with hooks!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
136
test/jit_hooks/test_jit_hooks.cpp
Normal file
136
test/jit_hooks/test_jit_hooks.cpp
Normal file
@ -0,0 +1,136 @@
|
||||
#include <torch/script.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
void test_module_forward_invocation_no_hooks_run(
|
||||
const std::string &path_to_exported_script_module) {
|
||||
std::cout << "testing: "
|
||||
<< "test_module_forward_invocation_no_hooks_run" << std::endl;
|
||||
torch::jit::Module module =
|
||||
torch::jit::load(path_to_exported_script_module + "_" +
|
||||
"test_module_forward_multiple_inputs" + ".pt");
|
||||
std::vector<torch::jit::IValue> inputs = {torch::List<std::string>({"a"}),
|
||||
torch::jit::IValue("no_pre_hook")};
|
||||
|
||||
auto output = module(inputs);
|
||||
auto output_forward = module.forward(inputs);
|
||||
torch::jit::IValue correct_direct_output =
|
||||
std::tuple<torch::List<std::string>, std::string>(
|
||||
{"a", "outer_mod_name", "inner_mod_name"}, "no_pre_hook_");
|
||||
std::cout << "----- module output: " << output << std::endl;
|
||||
std::cout << "----- module forward output: " << output_forward << std::endl;
|
||||
AT_ASSERT(correct_direct_output == output_forward);
|
||||
}
|
||||
|
||||
void test_submodule_called_directly_with_hooks(
|
||||
const std::string &path_to_exported_script_module) {
|
||||
std::cout << "testing: "
|
||||
<< "test_submodule_to_call_directly_with_hooks" << std::endl;
|
||||
torch::jit::Module module =
|
||||
torch::jit::load(path_to_exported_script_module + "_" +
|
||||
"test_submodule_to_call_directly_with_hooks" + ".pt");
|
||||
torch::jit::Module submodule = *module.modules().begin();
|
||||
std::vector<torch::jit::IValue> inputs = {"a"};
|
||||
|
||||
auto output = submodule(inputs);
|
||||
torch::jit::IValue correct_output = "pre_hook_override_name_inner_mod_fh";
|
||||
std::cout << "----- submodule's output: " << output << std::endl;
|
||||
std::cout << "----- expected output : " << correct_output << std::endl;
|
||||
AT_ASSERT(correct_output == correct_output);
|
||||
}
|
||||
|
||||
struct HooksTestCase {
|
||||
std::string name;
|
||||
std::vector<torch::jit::IValue> inputs;
|
||||
torch::jit::IValue output;
|
||||
HooksTestCase(std::string name, std::vector<torch::jit::IValue> inputs,
|
||||
torch::jit::IValue output)
|
||||
: name(name), inputs(std::move(inputs)), output(std::move(output)) {}
|
||||
};
|
||||
|
||||
int main(int argc, const char *argv[]) {
|
||||
if (argc != 2) {
|
||||
std::cerr << "usage: test_jit_hooks <path-to-exported-script-module>\n";
|
||||
return -1;
|
||||
}
|
||||
const std::string path_to_exported_script_module = argv[1];
|
||||
std::cout << "path to exported module:" << path_to_exported_script_module
|
||||
<< std::endl;
|
||||
std::cout << "Tesing JIT Hooks in CPP" << std::endl;
|
||||
|
||||
// Note: Modules loaded in this file are produced in /test/jit_hooks/model.py
|
||||
|
||||
std::vector<HooksTestCase> test_cases = {
|
||||
HooksTestCase("test_submodule_multiple_hooks_single_input",
|
||||
{torch::jit::IValue("a")},
|
||||
"pre_hook_override_name2_inner_mod_fwh1"),
|
||||
HooksTestCase("test_submodule_hook_return_nothing",
|
||||
{torch::jit::IValue("a")}, "a_outermod_inner_mod"),
|
||||
HooksTestCase("test_submodule_same_hook_repeated",
|
||||
{torch::jit::IValue("a")},
|
||||
"a_outermod_ph_ph_inner_mod_fh_fh"),
|
||||
HooksTestCase("test_submodule_forward_single_input",
|
||||
{torch::jit::IValue("a")},
|
||||
"pre_hook_override_name_inner_mod"),
|
||||
HooksTestCase(
|
||||
"test_submodule_multiple_hooks_multiple_inputs",
|
||||
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
|
||||
std::tuple<torch::List<std::string>, std::string>(
|
||||
{"pre_hook_override_name", "inner_mod_name"},
|
||||
"pre_hook_override2_fh1_fh2")),
|
||||
HooksTestCase(
|
||||
"test_submodule_forward_multiple_inputs",
|
||||
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
|
||||
std::tuple<torch::List<std::string>, std::string>(
|
||||
{"pre_hook_override_name", "inner_mod_name"},
|
||||
"pre_hook_override_fh")),
|
||||
HooksTestCase("test_module_forward_single_input",
|
||||
{torch::jit::IValue("a")},
|
||||
"pre_hook_override_name_outermod_inner_mod_fh"),
|
||||
HooksTestCase("test_module_multiple_hooks_single_input",
|
||||
{torch::jit::IValue("a")},
|
||||
"pre_hook_override_name2_outermod_inner_mod_fh1_fh2"),
|
||||
HooksTestCase("test_module_hook_return_nothing",
|
||||
{torch::jit::IValue("a")}, "a_outermod_inner_mod"),
|
||||
HooksTestCase("test_module_same_hook_repeated", {torch::jit::IValue("a")},
|
||||
"a_ph_ph_outermod_inner_mod_fh_fh"),
|
||||
HooksTestCase(
|
||||
"test_module_forward_multiple_inputs",
|
||||
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
|
||||
std::tuple<torch::List<std::string>, std::string>(
|
||||
{"pre_hook_override_name", "outer_mod_name", "inner_mod_name"},
|
||||
"pre_hook_override_fh")),
|
||||
HooksTestCase(
|
||||
"test_module_multiple_hooks_multiple_inputs",
|
||||
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
|
||||
std::tuple<torch::List<std::string>, std::string>(
|
||||
{"pre_hook_override_name2", "outer_mod_name", "inner_mod_name"},
|
||||
"pre_hook_override_fh1_fh2")),
|
||||
HooksTestCase("test_module_no_forward_input", {}, torch::jit::IValue()),
|
||||
HooksTestCase("test_forward_tuple_input", {std::tuple<int>(11)},
|
||||
{std::tuple<int>(11)}),
|
||||
};
|
||||
|
||||
for (HooksTestCase &test_case : test_cases) {
|
||||
std::cout << "testing: " << test_case.name << std::endl;
|
||||
torch::jit::Module module = torch::jit::load(
|
||||
path_to_exported_script_module + "_" + test_case.name + ".pt");
|
||||
torch::jit::IValue output = module(test_case.inputs);
|
||||
std::cout << "----- module's output: " << output << std::endl;
|
||||
std::cout << "----- expected output: " << test_case.output << std::endl;
|
||||
AT_ASSERT(output == test_case.output);
|
||||
}
|
||||
|
||||
// special test cases that don't call the imported module directly
|
||||
test_module_forward_invocation_no_hooks_run(path_to_exported_script_module);
|
||||
test_submodule_called_directly_with_hooks(path_to_exported_script_module);
|
||||
|
||||
std::cout << "JIT CPP Hooks okay!" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user