[PyTorch Edge][Version] Fix torchscript model after backport (#58892)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58892

The torchscript model after backport misses the `constants` archive. Add it back, and extend the unit test to run the torchscript part.
ghstack-source-id: 129853819

Test Plan:
```
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit
- LiteInterpreterTest.BackPortByteCodeModelAllVersions'
```

Reviewed By: raziel, iseeyuan

Differential Revision: D28664507

fbshipit-source-id: 5f98723231cc64ed203c062ee6f00d8adbdccf77
This commit is contained in:
Chen Lai
2021-05-25 15:33:50 -07:00
committed by Facebook GitHub Bot
parent fb120493b1
commit 60af6e928a
2 changed files with 33 additions and 7 deletions

View File

@ -624,6 +624,34 @@ TEST(LiteInterpreterTest, GetByteCodeVersion) {
}
namespace {
// Check a model's actual outputs against the expected tensors.
// Outputs 0 and 2 must be exactly equal to their expected counterparts;
// output 1 is only required to have the same number of dimensions
// (its values are not compared, per the original test's intent).
void compareModelOutput(
    const std::vector<IValue>& actual_result_list,
    const std::vector<Tensor>& expect_result_list) {
  AT_ASSERT(actual_result_list.size() == expect_result_list.size());
  const auto first_out = actual_result_list[0].toTensor();
  AT_ASSERT(first_out.equal(expect_result_list[0]));
  const auto second_out = actual_result_list[1].toTensor();
  AT_ASSERT(second_out.dim() == expect_result_list[1].dim());
  const auto third_out = actual_result_list[2].toTensor();
  AT_ASSERT(third_out.equal(expect_result_list[2]));
}
// Verify a backported model via the full TorchScript (non-mobile) path:
// confirm the serialized stream carries the expected bytecode version,
// then load it with torch::jit::load, run forward(), and compare the
// tuple outputs against the expected tensors.
//
// NOTE(review): the stream is consumed twice — first by
// _get_model_bytecode_version, then by load(). This assumes the version
// probe leaves (or rewinds) the stream position so that load() still sees
// the whole archive; confirm against _get_model_bytecode_version's
// implementation.
void runAndCheckTorchScriptModel(
    std::stringstream& input_model_stream,
    const std::vector<IValue>& input_data,
    const std::vector<Tensor>& expect_result_list,
    const int64_t expect_version) {
  auto actual_version = _get_model_bytecode_version(input_model_stream);
  AT_ASSERT(actual_version == expect_version);
  // Load and run the backport model, then compare the result with expect
  // result
  Module m_mobile = load(input_model_stream);
  auto actual_result = m_mobile.forward(input_data);
  // forward() returns a tuple for this test model; flatten it for comparison.
  std::vector<IValue> actual_result_list = actual_result.toTuple()->elements();
  compareModelOutput(actual_result_list, expect_result_list);
}
void runAndCheckBytecodeModel(
std::stringstream& input_model_stream,
const std::vector<IValue>& input_data,
@ -634,16 +662,12 @@ void runAndCheckBytecodeModel(
// Load and run the backport model, then compare the result with expect
// result
mobile::Module m_mobile = _load_for_mobile(input_model_stream);
Module m_mobile = load(input_model_stream);
auto actual_result = m_mobile.forward(input_data);
std::vector<IValue> actual_result_list = actual_result.toTuple()->elements();
AT_ASSERT(actual_result_list.size() == expect_result_list.size());
AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0]));
AT_ASSERT(
actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
compareModelOutput(actual_result_list, expect_result_list);
}
void backportAllVersionCheck(
@ -676,6 +700,8 @@ void backportAllVersionCheck(
// result
runAndCheckBytecodeModel(
iss, input_data, expect_result_list, current_to_version);
runAndCheckTorchScriptModel(
iss, input_data, expect_result_list, current_to_version);
current_to_version--;
}