mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 06:11:27 +08:00
[PyTorch Edge][Version] Fix torchscript model after backport (#58892)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58892 The torchscript model after backport misses the `constants` archive. Add it back, and extend the unit test to run torchscript part. ghstack-source-id: 129853819 Test Plan: ``` buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.BackPortByteCodeModelAllVersions' ``` Reviewed By: raziel, iseeyuan Differential Revision: D28664507 fbshipit-source-id: 5f98723231cc64ed203c062ee6f00d8adbdccf77
This commit is contained in:
committed by
Facebook GitHub Bot
parent
fb120493b1
commit
60af6e928a
@ -624,6 +624,34 @@ TEST(LiteInterpreterTest, GetByteCodeVersion) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void compareModelOutput(
|
||||
const std::vector<IValue>& actual_result_list,
|
||||
const std::vector<Tensor>& expect_result_list) {
|
||||
AT_ASSERT(actual_result_list.size() == expect_result_list.size());
|
||||
AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0]));
|
||||
AT_ASSERT(
|
||||
actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
|
||||
AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
|
||||
}
|
||||
|
||||
void runAndCheckTorchScriptModel(
|
||||
std::stringstream& input_model_stream,
|
||||
const std::vector<IValue>& input_data,
|
||||
const std::vector<Tensor>& expect_result_list,
|
||||
const int64_t expect_version) {
|
||||
auto actual_version = _get_model_bytecode_version(input_model_stream);
|
||||
AT_ASSERT(actual_version == expect_version);
|
||||
|
||||
// Load and run the backport model, then compare the result with expect
|
||||
// result
|
||||
Module m_mobile = load(input_model_stream);
|
||||
|
||||
auto actual_result = m_mobile.forward(input_data);
|
||||
std::vector<IValue> actual_result_list = actual_result.toTuple()->elements();
|
||||
compareModelOutput(actual_result_list, expect_result_list);
|
||||
}
|
||||
|
||||
void runAndCheckBytecodeModel(
|
||||
std::stringstream& input_model_stream,
|
||||
const std::vector<IValue>& input_data,
|
||||
@ -634,16 +662,12 @@ void runAndCheckBytecodeModel(
|
||||
|
||||
// Load and run the backport model, then compare the result with expect
|
||||
// result
|
||||
mobile::Module m_mobile = _load_for_mobile(input_model_stream);
|
||||
Module m_mobile = load(input_model_stream);
|
||||
|
||||
auto actual_result = m_mobile.forward(input_data);
|
||||
std::vector<IValue> actual_result_list = actual_result.toTuple()->elements();
|
||||
|
||||
AT_ASSERT(actual_result_list.size() == expect_result_list.size());
|
||||
AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0]));
|
||||
AT_ASSERT(
|
||||
actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
|
||||
AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
|
||||
compareModelOutput(actual_result_list, expect_result_list);
|
||||
}
|
||||
|
||||
void backportAllVersionCheck(
|
||||
@ -676,6 +700,8 @@ void backportAllVersionCheck(
|
||||
// result
|
||||
runAndCheckBytecodeModel(
|
||||
iss, input_data, expect_result_list, current_to_version);
|
||||
runAndCheckTorchScriptModel(
|
||||
iss, input_data, expect_result_list, current_to_version);
|
||||
|
||||
current_to_version--;
|
||||
}
|
||||
|
Reference in New Issue
Block a user