Add the support of feature store example in pytorch model in fblearner (#20040)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20040

Add the support of feature store example in fblearner pytorch predictor, end to end

Reviewed By: dzhulgakov

Differential Revision: D15177897

fbshipit-source-id: 0f6df8b064eb9844fc9ddae61e978d6574c22916
This commit is contained in:
Lu Fang
2019-05-20 12:49:56 -07:00
committed by Facebook Github Bot
parent 9fbce974c9
commit af6eea9391
4 changed files with 23 additions and 7 deletions

View File

@ -165,6 +165,18 @@ static std::string getPadding(size_t cursor, const std::string& filename, size_t
return buf;
}
bool PyTorchStreamReader::hasFile(const std::string& name) {
std::stringstream ss;
ss << archive_name_ << "/" << name;
mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
bool result = ar_->m_last_error != MZ_ZIP_FILE_NOT_FOUND;
if (!result) {
ar_->m_last_error = MZ_ZIP_NO_ERROR;
}
valid("attempting to locate file");
return result;
}
size_t PyTorchStreamReader::getFileID(const std::string& name) {
std::stringstream ss;
ss << archive_name_ << "/" << name;

View File

@ -106,8 +106,8 @@ class CAFFE2_API PyTorchStreamReader final {
// return dataptr, size
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
size_t getRecordOffset(const std::string& name);
bool hasFile(const std::string& name);
~PyTorchStreamReader();

View File

@ -39,6 +39,9 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
// read records through readers
PyTorchStreamReader reader(&iss);
ASSERT_TRUE(reader.hasFile("key1"));
ASSERT_TRUE(reader.hasFile("key2"));
ASSERT_FALSE(reader.hasFile("key2000"));
at::DataPtr data_ptr;
int64_t size;
std::tie(data_ptr, size) = reader.getRecord("key1");
@ -48,7 +51,6 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
ASSERT_EQ(off1 % kFieldAlignment, 0);
std::tie(data_ptr, size) = reader.getRecord("key2");
size_t off2 = reader.getRecordOffset("key2");
ASSERT_EQ(off2 % kFieldAlignment, 0);

View File

@ -134,11 +134,13 @@ void ScriptModuleDeserializer::deserialize(
// Load extra files.
for (const auto& kv : extra_files) {
const std::string& key = "extra/" + kv.first;
at::DataPtr meta_ptr;
size_t meta_size;
std::tie(meta_ptr, meta_size) = reader_.getRecord(key);
extra_files[kv.first] =
std::string(static_cast<char*>(meta_ptr.get()), meta_size);
if (reader_.hasFile(key)) {
at::DataPtr meta_ptr;
size_t meta_size;
std::tie(meta_ptr, meta_size) = reader_.getRecord(key);
extra_files[kv.first] =
std::string(static_cast<char*>(meta_ptr.get()), meta_size);
}
}
loadTensorTable(&model_def);