mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12714 This is a short change to enable c10 namespace in caffe2. We did not enable it before due to gflags global variable confusion, but it should have been mostly cleaned now. Right now, the plan on record is that namespace caffe2 and namespace aten will fully be supersets of namespace c10. Most of the diff is codemod, and only two places of non-codemod is in caffe2/core/common.h, where ``` using namespace c10; ``` is added, and in Flags.h, where instead of creating aliasing variables in c10 namespace, we directly put it in the global namespace to match gflags (and same behavior if gflags is not being built with). Reviewed By: dzhulgakov Differential Revision: D10390486 fbshipit-source-id: 5e2df730e28e29a052f513bddc558d9f78a23b9b
147 lines
4.9 KiB
C++
147 lines
4.9 KiB
C++
/**
|
|
* Copyright (c) 2016-present, Facebook, Inc.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
//
// This script converts the CIFAR dataset to the db format (leveldb by
// default) used by caffe to perform classification.
// Usage:
//    convert_cifar_data --input_folder=<dir> \
//        --output_train_db_name=<path> --output_test_db_name=<path> \
//        [--db=<db type>] [--is_cifar100=true]
// The CIFAR dataset can be downloaded at
//    http://www.cs.toronto.edu/~kriz/cifar.html
|
|
|
|
#include <array>
|
|
#include <fstream> // NOLINT(readability/streams)
|
|
#include <sstream>
|
|
#include <string>
|
|
|
|
#include "caffe2/core/common.h"
|
|
#include "caffe2/core/db.h"
|
|
#include "caffe2/core/init.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
#include "caffe2/core/logging.h"
|
|
|
|
// Command-line flags read by the conversion routines below.

// Directory that holds the CIFAR *.bin input files.
C10_DEFINE_string(input_folder, "", "The input folder name.");

// Names of the two databases that are created.
C10_DEFINE_string(output_train_db_name, "", "The output training db name.");
C10_DEFINE_string(output_test_db_name, "", "The output testing db name.");

// Database backend name handed to db::CreateDB.
C10_DEFINE_string(db, "leveldb", "The db type.");

// Selects the CIFAR-100 record layout and file names instead of CIFAR-10.
C10_DEFINE_bool(
    is_cifar100, false, "If set, convert cifar100. Otherwise do cifar10.");
|
|
|
|
namespace caffe2 {
|
|
|
|
using std::stringstream;

// Geometry of one image: kCIFARSize x kCIFARSize pixels, three one-byte
// channels per pixel.
constexpr int kCIFARSize = 32;
constexpr int kCIFARImageNBytes = 3 * kCIFARSize * kCIFARSize;

// CIFAR-10: the training set is split over kCIFAR10TrainBatches files of
// kCIFAR10BatchSize records each; the test set is a single file.
constexpr int kCIFAR10BatchSize = 10000;
constexpr int kCIFAR10TrainBatches = 5;
constexpr int kCIFAR10TestDataSize = 10000;

// CIFAR-100: one training file and one test file.
constexpr int kCIFAR100TrainDataSize = 50000;
constexpr int kCIFAR100TestDataSize = 10000;
|
|
|
|
void ReadImage(std::ifstream* file, int* label, char* buffer) {
|
|
char label_char;
|
|
if (FLAGS_is_cifar100) {
|
|
// Skip the coarse label.
|
|
file->read(&label_char, 1);
|
|
}
|
|
file->read(&label_char, 1);
|
|
*label = label_char;
|
|
// Yes, there are better ways to do it, like in-place swap... but I am too
|
|
// lazy so let's just write it in a memory-wasteful way.
|
|
std::array<char, kCIFARImageNBytes> channel_first_storage;
|
|
file->read(channel_first_storage.data(), kCIFARImageNBytes);
|
|
for (int c = 0; c < 3; ++c) {
|
|
for (int i = 0; i < kCIFARSize * kCIFARSize; ++i) {
|
|
buffer[i * 3 + c] =
|
|
channel_first_storage[c * kCIFARSize * kCIFARSize + i];
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
|
|
void WriteToDB(const string& filename, const int num_items,
|
|
const int& offset, db::DB* db) {
|
|
TensorProtos protos;
|
|
TensorProto* data = protos.add_protos();
|
|
TensorProto* label = protos.add_protos();
|
|
data->set_data_type(TensorProto::BYTE);
|
|
data->add_dims(kCIFARSize);
|
|
data->add_dims(kCIFARSize);
|
|
data->add_dims(3);
|
|
label->set_data_type(TensorProto::INT32);
|
|
label->add_dims(1);
|
|
label->add_int32_data(0);
|
|
|
|
LOG(INFO) << "Converting file " << filename;
|
|
std::ifstream data_file(filename.c_str(),
|
|
std::ios::in | std::ios::binary);
|
|
CAFFE_ENFORCE(data_file, "Unable to open file ", filename);
|
|
char str_buffer[kCIFARImageNBytes];
|
|
int label_value;
|
|
string serialized_protos;
|
|
std::unique_ptr<db::Transaction> transaction(db->NewTransaction());
|
|
for (int itemid = 0; itemid < num_items; ++itemid) {
|
|
ReadImage(&data_file, &label_value, str_buffer);
|
|
data->set_byte_data(str_buffer, kCIFARImageNBytes);
|
|
label->set_int32_data(0, label_value);
|
|
protos.SerializeToString(&serialized_protos);
|
|
snprintf(str_buffer, kCIFARImageNBytes, "%05d",
|
|
offset + itemid);
|
|
transaction->Put(string(str_buffer), serialized_protos);
|
|
}
|
|
}
|
|
|
|
void ConvertCIFAR() {
|
|
std::unique_ptr<db::DB> train_db(
|
|
db::CreateDB(FLAGS_db, FLAGS_output_train_db_name, db::NEW));
|
|
std::unique_ptr<db::DB> test_db(
|
|
db::CreateDB(FLAGS_db, FLAGS_output_test_db_name, db::NEW));
|
|
|
|
if (!FLAGS_is_cifar100) {
|
|
// This is cifar 10.
|
|
for (int fileid = 0; fileid < kCIFAR10TrainBatches; ++fileid) {
|
|
stringstream train_file;
|
|
train_file << FLAGS_input_folder << "/data_batch_" << fileid + 1
|
|
<< ".bin";
|
|
WriteToDB(train_file.str(), kCIFAR10BatchSize,
|
|
fileid * kCIFAR10BatchSize, train_db.get());
|
|
}
|
|
stringstream test_file;
|
|
test_file << FLAGS_input_folder << "/test_batch.bin";
|
|
WriteToDB(test_file.str(), kCIFAR10TestDataSize, 0, test_db.get());
|
|
} else {
|
|
// This is cifar 100.
|
|
stringstream train_file;
|
|
train_file << FLAGS_input_folder << "/train.bin";
|
|
WriteToDB(train_file.str(), kCIFAR100TrainDataSize, 0, train_db.get());
|
|
stringstream test_file;
|
|
test_file << FLAGS_input_folder << "/test.bin";
|
|
WriteToDB(test_file.str(), kCIFAR100TestDataSize, 0, test_db.get());
|
|
}
|
|
}
|
|
|
|
} // namespace caffe2
|
|
|
|
/**
 * Entry point: hands argc/argv to caffe2::GlobalInit and then runs the
 * conversion.
 *
 * @return 0 after the conversion ran, 1 when initialization reports failure.
 */
int main(int argc, char** argv) {
  // NOTE(review): relies on GlobalInit returning a bool success status, as
  // declared in caffe2/core/init.h -- confirm. Converting after a failed
  // initialization would run with unreliable flag values.
  if (!caffe2::GlobalInit(&argc, &argv)) {
    return 1;
  }
  caffe2::ConvertCIFAR();
  return 0;
}
|