mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 00:54:56 +08:00
Back out "[pytorch][PR] Add ability for a mobile::Module to save as flatbuffer" (#69796)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69796 (Note: this ignores all push blocking failures!) Test Plan: External CI + Sandcastle Reviewed By: zhxchen17 Differential Revision: D33032671 fbshipit-source-id: dbf6690e960e25d6a5f19043cbe792add2acd7ef
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3906f8247a
commit
17f3179d60
@ -2,21 +2,13 @@
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/backends/backend_detail.h>
|
||||
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
// Tests go in torch::jit
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
mobile::Module load_mobile_module(void* data, size_t) {
|
||||
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
|
||||
return initialize_mobile_module(flatbuffer_module);
|
||||
}
|
||||
|
||||
TEST(BackendTest, ToBackend) {
|
||||
Module m("m");
|
||||
m.define(R"(
|
||||
@ -149,11 +141,6 @@ TEST(BackendTest, TestCompiler) {
|
||||
auto mlm = _load_for_mobile(ss);
|
||||
auto mres = mlm.forward(inputs);
|
||||
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
|
||||
|
||||
auto buff = save_mobile_module_to_bytes(mlm);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
auto mres2 = mlm2.forward(inputs);
|
||||
AT_ASSERT(mres2.toTensor().equal(ref.toTensor()));
|
||||
}
|
||||
|
||||
TEST(BackendTest, TestComposite) {
|
||||
@ -196,12 +183,8 @@ TEST(BackendTest, TestComposite) {
|
||||
c._save_for_mobile(ss);
|
||||
auto mc = _load_for_mobile(ss);
|
||||
auto res_mobile = mc.forward(inputs);
|
||||
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
|
||||
|
||||
auto buff = save_mobile_module_to_bytes(mc);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
auto mres2 = mlm2.forward(inputs);
|
||||
AT_ASSERT(mres2.toTensor().equal(res_jit.toTensor()));
|
||||
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
|
||||
}
|
||||
|
||||
Module getCompositeModuleWithSameNameSubModules() {
|
||||
@ -258,11 +241,6 @@ TEST(BackendTest, TestCompositeWithSetStates) {
|
||||
auto mc = _load_for_mobile(ss);
|
||||
auto res_mobile = mc.forward(inputs);
|
||||
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
|
||||
|
||||
auto buff = save_mobile_module_to_bytes(mc);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
auto mres2 = mlm2.forward(inputs);
|
||||
AT_ASSERT(mres2.toTensor().equal(res_jit.toTensor()));
|
||||
}
|
||||
|
||||
TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
|
||||
@ -278,11 +256,6 @@ TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
|
||||
auto mc = _load_for_mobile(ss);
|
||||
auto res_mobile = mc.forward(inputs);
|
||||
|
||||
auto buff = save_mobile_module_to_bytes(mc);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
auto mres2 = mlm2.forward(inputs);
|
||||
AT_ASSERT(mres2.toTensor().equal(res_mobile.toTensor()));
|
||||
|
||||
// check if the methods names are always the same
|
||||
// by reloading the script module and saving it back as mobile
|
||||
// The below checks ensure that the names of Methods
|
||||
@ -381,13 +354,6 @@ Traceback of TorchScript (most recent call last):
|
||||
~~~~~ <--- HERE
|
||||
)";
|
||||
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
||||
|
||||
/* TODO(add debug info to flatbuffer)
|
||||
auto buff = save_mobile_module_to_bytes(mlm);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
mlm2.forward(inputs);
|
||||
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||
*/
|
||||
}
|
||||
|
||||
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
|
||||
@ -448,12 +414,6 @@ Traceback of TorchScript (most recent call last):
|
||||
~~~~~ <--- HERE
|
||||
)";
|
||||
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
||||
|
||||
/* TODO(add debug info to flatbuffer)
|
||||
auto buff = save_mobile_module_to_bytes(mlm);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||
*/
|
||||
}
|
||||
|
||||
TEST(
|
||||
@ -552,13 +512,7 @@ Traceback of TorchScript (most recent call last):
|
||||
return x + y
|
||||
~~~~~ <--- HERE
|
||||
)";
|
||||
|
||||
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
||||
/* TODO(add debug info to flatbuffer)
|
||||
auto buff = save_mobile_module_to_bytes(mlm);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||
*/
|
||||
}
|
||||
|
||||
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) {
|
||||
@ -640,11 +594,6 @@ Traceback of TorchScript (most recent call last):
|
||||
~~~~~ <--- HERE
|
||||
)";
|
||||
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
|
||||
/* TODO(add debug info to flatbuffer)
|
||||
auto buff = save_mobile_module_to_bytes(c_loaded);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||
*/
|
||||
}
|
||||
|
||||
TEST(
|
||||
@ -772,11 +721,6 @@ Traceback of TorchScript (most recent call last):
|
||||
~~~~~ <--- HERE
|
||||
)";
|
||||
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
|
||||
/* TODO(add debug info to flatbuffer)
|
||||
auto buff = save_mobile_module_to_bytes(c_loaded);
|
||||
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||
*/
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
||||
Reference in New Issue
Block a user