Mirror of https://github.com/pytorch/pytorch.git
Enable test_api IMethodTest in OSS (#62521)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62521

This diff does two things to enable the tests:
1. Exposes IMethod as TORCH_API.
2. Links torch_deploy to test_api when USE_DEPLOY == 1.

Test Plan: ./build/bin/test_api --gtest_filter=IMethodTest.*

Note that `python torch/csrc/deploy/example/generate_examples.py` needs to be run before the above command.

Reviewed By: ezyang
Differential Revision: D30055372
Pulled By: alanwaketan
fbshipit-source-id: 50eb3689cf84ed0f48be58cd109afcf61ecca508
Committed by: Facebook GitHub Bot
Parent: a749180e4e
Commit: 4b68801c69
test/cpp/api/CMakeLists.txt

@@ -41,6 +41,10 @@ set(TORCH_API_TEST_SOURCES
   ${TORCH_API_TEST_DIR}/grad_mode.cpp
 )
 
+if(USE_DEPLOY)
+  list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/imethod.cpp)
+endif()
+
 if(USE_CUDA)
   list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/parallel.cpp)
 endif()

@@ -59,6 +63,10 @@ if(USE_CUDA)
   target_compile_definitions(test_api PRIVATE "USE_CUDA")
 endif()
 
+if(USE_DEPLOY)
+  target_link_libraries(test_api PRIVATE torch_deploy)
+endif()
+
 # Workaround for https://github.com/pytorch/pytorch/issues/40941
 if(USE_OPENMP AND CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0.0))
   # Compiling transformer.cpp or pow_test.cpp with -O2+ and both -fuse-openmp and -faligned-new … out any optimization
test/cpp/api/imethod.cpp

@@ -8,30 +8,38 @@
 using namespace ::testing;
 using namespace caffe2;
 
-// TODO(T96218435): Enable the following tests in OSS.
+const char* simple = "torch/csrc/deploy/example/generated/simple";
+const char* simpleJit = "torch/csrc/deploy/example/generated/simple_jit";
+
+// TODO(jwtan): Try unifying cmake and buck for getting the path.
+const char* path(const char* envname, const char* path) {
+  const char* env = getenv(envname);
+  return env ? env : path;
+}
+
 TEST(IMethodTest, CallMethod) {
-  auto script_model = torch::jit::load(getenv("SIMPLE_JIT"));
-  auto script_method = script_model.get_method("forward");
+  auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit));
+  auto scriptMethod = scriptModel.get_method("forward");
 
   torch::deploy::InterpreterManager manager(3);
-  torch::deploy::Package p = manager.load_package(getenv("SIMPLE"));
-  auto py_model = p.load_pickle("model", "model.pkl");
-  torch::deploy::PythonMethodWrapper py_method(py_model, "forward");
+  torch::deploy::Package package = manager.load_package(path("SIMPLE", simple));
+  auto pyModel = package.load_pickle("model", "model.pkl");
+  torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");
 
   auto input = torch::ones({10, 20});
-  auto output_py = py_method({input});
-  auto output_script = script_method({input});
-  EXPECT_TRUE(output_py.isTensor());
-  EXPECT_TRUE(output_script.isTensor());
-  auto output_py_tensor = output_py.toTensor();
-  auto output_script_tensor = output_script.toTensor();
+  auto outputPy = pyMethod({input});
+  auto outputScript = scriptMethod({input});
+  EXPECT_TRUE(outputPy.isTensor());
+  EXPECT_TRUE(outputScript.isTensor());
+  auto outputPyTensor = outputPy.toTensor();
+  auto outputScriptTensor = outputScript.toTensor();
 
-  EXPECT_TRUE(output_py_tensor.equal(output_script_tensor));
-  EXPECT_EQ(output_py_tensor.numel(), 200);
+  EXPECT_TRUE(outputPyTensor.equal(outputScriptTensor));
+  EXPECT_EQ(outputPyTensor.numel(), 200);
 }
 
 TEST(IMethodTest, GetArgumentNames) {
-  auto scriptModel = torch::jit::load(getenv("SIMPLE_JIT"));
+  auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit));
   auto scriptMethod = scriptModel.get_method("forward");
 
   auto& scriptNames = scriptMethod.getArgumentNames();

@@ -39,7 +47,7 @@ TEST(IMethodTest, GetArgumentNames) {
   EXPECT_STREQ(scriptNames[0].c_str(), "input");
 
   torch::deploy::InterpreterManager manager(3);
-  torch::deploy::Package package = manager.load_package(getenv("SIMPLE"));
+  torch::deploy::Package package = manager.load_package(path("SIMPLE", simple));
   auto pyModel = package.load_pickle("model", "model.pkl");
   torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");
 
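The tests above drive a TorchScript method and a deploy-backed Python method through the same call shape. A minimal sketch of that usage pattern, assuming (as the tests suggest) that both torch::jit::Method and torch::deploy::PythonMethodWrapper can be treated as a torch::IMethod; the helper runOnes and its includes are illustrative, not part of this diff:

#include <torch/csrc/deploy/deploy.h>
#include <torch/script.h>

// Hypothetical helper, not part of the PR: call any IMethod-backed "forward"
// with the same input the test uses and return the result as a tensor.
at::Tensor runOnes(torch::IMethod& method) {
  auto input = torch::ones({10, 20});
  auto output = method({input}); // arguments are passed as a list of IValues
  return output.toTensor();      // CallMethod checks the result is a tensor with 200 elements
}

Under that assumption, both scriptMethod and pyMethod from the test could be handed to such a helper, which is the portability the comment in the IMethod header describes.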
torch/csrc/api/include/torch/imethod.h

@@ -4,7 +4,7 @@
 
 namespace torch {
 
-class IMethod {
+class TORCH_API IMethod {
   /*
   IMethod provides a portable interface for torch methods, whether
   they are backed by torchscript or python/deploy.