mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12848 Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call SerializeAsString_EnforceCheck so that the return value is checked and can throw an exception if failing. Most of the affected code was called from classes derived from BlobSerializeBase. Didn't touch most tests and ENFORCE calls because they usually do checks anyway. Original commit changeset: c0760e73ecc7 Reviewed By: dzhulgakov Differential Revision: D10453456 fbshipit-source-id: d2f2b7b4578e721924354149f08f627c7e3bf070
91 lines
3.3 KiB
C++
91 lines
3.3 KiB
C++
/**
|
|
* Copyright (c) 2016-present, Facebook, Inc.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "caffe2/core/db.h"
|
|
#include "caffe2/core/init.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
#include "caffe2/proto/caffe2_legacy.pb.h"
|
|
#include "caffe2/core/logging.h"
|
|
|
|
// Command-line flags for the converter.
// input_db / input_db_type are handed to caffe2::db::CreateDB in READ mode,
// and output_db / output_db_type are handed to it in NEW mode (see main()).
C10_DEFINE_string(input_db, "", "The input db.");
|
|
C10_DEFINE_string(input_db_type, "", "The input db type.");
|
|
C10_DEFINE_string(output_db, "", "The output db.");
|
|
C10_DEFINE_string(output_db_type, "", "The output db type.");
|
|
// Number of converted records written between transaction commits in main().
// It is used as a modulus there, so it must not be zero.
C10_DEFINE_int(batch_size, 1000, "The write batch size.");
|
|
|
|
using caffe2::db::Cursor;
|
|
using caffe2::db::DB;
|
|
using caffe2::db::Transaction;
|
|
using caffe2::CaffeDatum;
|
|
using caffe2::TensorProto;
|
|
using caffe2::TensorProtos;
|
|
|
|
int main(int argc, char** argv) {
|
|
caffe2::GlobalInit(&argc, &argv);
|
|
|
|
std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
|
|
FLAGS_input_db_type, FLAGS_input_db, caffe2::db::READ));
|
|
std::unique_ptr<DB> out_db(caffe2::db::CreateDB(
|
|
FLAGS_output_db_type, FLAGS_output_db, caffe2::db::NEW));
|
|
std::unique_ptr<Cursor> cursor(in_db->NewCursor());
|
|
std::unique_ptr<Transaction> transaction(out_db->NewTransaction());
|
|
int count = 0;
|
|
for (; cursor->Valid(); cursor->Next()) {
|
|
CaffeDatum datum;
|
|
CAFFE_ENFORCE(datum.ParseFromString(cursor->value()));
|
|
TensorProtos protos;
|
|
TensorProto* data = protos.add_protos();
|
|
TensorProto* label = protos.add_protos();
|
|
label->set_data_type(TensorProto::INT32);
|
|
label->add_dims(1);
|
|
label->add_int32_data(datum.label());
|
|
if (datum.encoded()) {
|
|
// This is an encoded image. we will copy over the data directly.
|
|
data->set_data_type(TensorProto::STRING);
|
|
data->add_dims(1);
|
|
data->add_string_data(datum.data());
|
|
} else {
|
|
// float data not supported right now.
|
|
CAFFE_ENFORCE_EQ(datum.float_data_size(), 0);
|
|
std::vector<char> buffer_vec(datum.data().size());
|
|
char* buffer = buffer_vec.data();
|
|
// swap order from CHW to HWC
|
|
int channels = datum.channels();
|
|
int size = datum.height() * datum.width();
|
|
CAFFE_ENFORCE_EQ(datum.data().size(), channels * size);
|
|
for (int c = 0; c < channels; ++c) {
|
|
char* dst = buffer + c;
|
|
const char* src = datum.data().c_str() + c * size;
|
|
for (int n = 0; n < size; ++n) {
|
|
dst[n*channels] = src[n];
|
|
}
|
|
}
|
|
data->set_data_type(TensorProto::BYTE);
|
|
data->add_dims(datum.height());
|
|
data->add_dims(datum.width());
|
|
data->add_dims(datum.channels());
|
|
data->set_byte_data(buffer, datum.data().size());
|
|
}
|
|
transaction->Put(cursor->key(), SerializeAsString_EnforceCheck(protos));
|
|
if (++count % FLAGS_batch_size == 0) {
|
|
transaction->Commit();
|
|
LOG(INFO) << "Converted " << count << " items so far.";
|
|
}
|
|
}
|
|
LOG(INFO) << "A total of " << count << " items processed.";
|
|
return 0;
|
|
}
|