mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add the support of feature store example in pytorch model in fblearner (#20040)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20040 Add the support of feature store example in fblearner pytorch predictor, end to end Reviewed By: dzhulgakov Differential Revision: D15177897 fbshipit-source-id: 0f6df8b064eb9844fc9ddae61e978d6574c22916
This commit is contained in:
committed by
Facebook Github Bot
parent
9fbce974c9
commit
af6eea9391
@ -165,6 +165,18 @@ static std::string getPadding(size_t cursor, const std::string& filename, size_t
|
||||
return buf;
|
||||
}
|
||||
|
||||
bool PyTorchStreamReader::hasFile(const std::string& name) {
|
||||
std::stringstream ss;
|
||||
ss << archive_name_ << "/" << name;
|
||||
mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
|
||||
bool result = ar_->m_last_error != MZ_ZIP_FILE_NOT_FOUND;
|
||||
if (!result) {
|
||||
ar_->m_last_error = MZ_ZIP_NO_ERROR;
|
||||
}
|
||||
valid("attempting to locate file");
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t PyTorchStreamReader::getFileID(const std::string& name) {
|
||||
std::stringstream ss;
|
||||
ss << archive_name_ << "/" << name;
|
||||
|
@ -106,8 +106,8 @@ class CAFFE2_API PyTorchStreamReader final {
|
||||
|
||||
// return dataptr, size
|
||||
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
|
||||
|
||||
size_t getRecordOffset(const std::string& name);
|
||||
bool hasFile(const std::string& name);
|
||||
|
||||
~PyTorchStreamReader();
|
||||
|
||||
|
@ -39,6 +39,9 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
|
||||
// read records through readers
|
||||
PyTorchStreamReader reader(&iss);
|
||||
ASSERT_TRUE(reader.hasFile("key1"));
|
||||
ASSERT_TRUE(reader.hasFile("key2"));
|
||||
ASSERT_FALSE(reader.hasFile("key2000"));
|
||||
at::DataPtr data_ptr;
|
||||
int64_t size;
|
||||
std::tie(data_ptr, size) = reader.getRecord("key1");
|
||||
@ -48,7 +51,6 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
|
||||
ASSERT_EQ(off1 % kFieldAlignment, 0);
|
||||
|
||||
|
||||
std::tie(data_ptr, size) = reader.getRecord("key2");
|
||||
size_t off2 = reader.getRecordOffset("key2");
|
||||
ASSERT_EQ(off2 % kFieldAlignment, 0);
|
||||
|
@ -134,11 +134,13 @@ void ScriptModuleDeserializer::deserialize(
|
||||
// Load extra files.
|
||||
for (const auto& kv : extra_files) {
|
||||
const std::string& key = "extra/" + kv.first;
|
||||
at::DataPtr meta_ptr;
|
||||
size_t meta_size;
|
||||
std::tie(meta_ptr, meta_size) = reader_.getRecord(key);
|
||||
extra_files[kv.first] =
|
||||
std::string(static_cast<char*>(meta_ptr.get()), meta_size);
|
||||
if (reader_.hasFile(key)) {
|
||||
at::DataPtr meta_ptr;
|
||||
size_t meta_size;
|
||||
std::tie(meta_ptr, meta_size) = reader_.getRecord(key);
|
||||
extra_files[kv.first] =
|
||||
std::string(static_cast<char*>(meta_ptr.get()), meta_size);
|
||||
}
|
||||
}
|
||||
|
||||
loadTensorTable(&model_def);
|
||||
|
Reference in New Issue
Block a user