Back out "[pytorch][PR] Add ability for a mobile::Module to save as flatbuffer" (#69796)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69796

(Note: this ignores all push blocking failures!)

Test Plan: External CI + Sandcastle

Reviewed By: zhxchen17

Differential Revision: D33032671

fbshipit-source-id: dbf6690e960e25d6a5f19043cbe792add2acd7ef
This commit is contained in:
Yanan Cao
2021-12-10 21:22:38 -08:00
committed by Facebook GitHub Bot
parent 3906f8247a
commit 17f3179d60
29 changed files with 19 additions and 2348 deletions

View File

@@ -2,21 +2,13 @@
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/torch.h>
// Tests go in torch::jit
namespace torch {
namespace jit {
// Test helper: reconstruct a mobile::Module from a flatbuffer-encoded buffer.
// NOTE(review): GetMutableModule reads the buffer in place, so `data` must
// outlive the returned module's initialization — confirm against callers.
// The size_t parameter is accepted but unused (kept for a save/load-style
// signature: pointer + length).
mobile::Module load_mobile_module(void* data, size_t) {
  return initialize_mobile_module(
      mobile::serialization::GetMutableModule(data));
}
TEST(BackendTest, ToBackend) {
Module m("m");
m.define(R"(
@@ -149,11 +141,6 @@ TEST(BackendTest, TestCompiler) {
auto mlm = _load_for_mobile(ss);
auto mres = mlm.forward(inputs);
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
auto buff = save_mobile_module_to_bytes(mlm);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
auto mres2 = mlm2.forward(inputs);
AT_ASSERT(mres2.toTensor().equal(ref.toTensor()));
}
TEST(BackendTest, TestComposite) {
@@ -196,12 +183,8 @@ TEST(BackendTest, TestComposite) {
c._save_for_mobile(ss);
auto mc = _load_for_mobile(ss);
auto res_mobile = mc.forward(inputs);
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
auto buff = save_mobile_module_to_bytes(mc);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
auto mres2 = mlm2.forward(inputs);
AT_ASSERT(mres2.toTensor().equal(res_jit.toTensor()));
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
}
Module getCompositeModuleWithSameNameSubModules() {
@@ -258,11 +241,6 @@ TEST(BackendTest, TestCompositeWithSetStates) {
auto mc = _load_for_mobile(ss);
auto res_mobile = mc.forward(inputs);
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
auto buff = save_mobile_module_to_bytes(mc);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
auto mres2 = mlm2.forward(inputs);
AT_ASSERT(mres2.toTensor().equal(res_jit.toTensor()));
}
TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
@@ -278,11 +256,6 @@ TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
auto mc = _load_for_mobile(ss);
auto res_mobile = mc.forward(inputs);
auto buff = save_mobile_module_to_bytes(mc);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
auto mres2 = mlm2.forward(inputs);
AT_ASSERT(mres2.toTensor().equal(res_mobile.toTensor()));
// check if the methods names are always the same
// by reloading the script module and saving it back as mobile
// The below checks ensure that the names of Methods
@@ -381,13 +354,6 @@ Traceback of TorchScript (most recent call last):
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(mlm);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
mlm2.forward(inputs);
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
@@ -448,12 +414,6 @@ Traceback of TorchScript (most recent call last):
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(mlm);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}
TEST(
@@ -552,13 +512,7 @@ Traceback of TorchScript (most recent call last):
return x + y
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(mlm);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) {
@@ -640,11 +594,6 @@ Traceback of TorchScript (most recent call last):
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(c_loaded);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}
TEST(
@@ -772,11 +721,6 @@ Traceback of TorchScript (most recent call last):
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(c_loaded);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}
} // namespace jit