Files
pytorch/binaries/make_mnist_db.cc
Yangqing Jia 7d5f7ed270 Using c10 namespace across caffe2. (#12714)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12714

This is a short change to enable c10 namespace in caffe2. We did not enable
it before due to gflags global variable confusion, but it should have been
mostly cleaned now. Right now, the plan on record is that namespace caffe2 and
namespace aten will fully be supersets of namespace c10.

Most of the diff is codemod; the only two non-codemod changes are in caffe2/core/common.h, where

```
using namespace c10;
```

is added, and in Flags.h, where instead of creating aliasing variables in the c10 namespace, we put them directly in the global namespace to match gflags (the behavior is the same when building without gflags).

Reviewed By: dzhulgakov

Differential Revision: D10390486

fbshipit-source-id: 5e2df730e28e29a052f513bddc558d9f78a23b9b
2018-10-17 12:57:19 -07:00

147 lines
4.8 KiB
C++

/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This script converts the MNIST dataset to a Caffe2 db (leveldb by default;
// see the --db flag).
// The MNIST dataset can be downloaded at
// http://yann.lecun.com/exdb/mnist/
#include <fstream> // NOLINT(readability/streams)
#include <string>
#include "caffe2/core/common.h"
#include "caffe2/core/db.h"
#include "caffe2/core/init.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/core/logging.h"
// Command-line flags. Each one is read below through its FLAGS_<name>
// variable, either in main() or in caffe2::convert_dataset().

// Paths of the two input files; main() forwards them to convert_dataset().
C10_DEFINE_string(image_file, "", "The input image file name.");
C10_DEFINE_string(label_file, "", "The label file name.");
// Path handed to db::CreateDB() as the location of the db to create.
C10_DEFINE_string(output_file, "", "The output db name.");
// Db type string handed to db::CreateDB().
C10_DEFINE_string(db, "leveldb", "The db type.");
// convert_dataset() only honors positive values, so the default of -1 means
// "no limit".
C10_DEFINE_int(
data_limit,
-1,
"If set, only output this number of data points.");
// Selects the dims recorded for each image: (1, rows, cols) when true,
// (rows, cols, 1) when false.
C10_DEFINE_bool(
channel_first,
false,
"If set, write the data as channel-first (CHW order) as the old "
"Caffe does.");
namespace caffe2 {
// Returns `val` with its four bytes in reverse order, i.e. converts a 32-bit
// integer between big-endian and little-endian representations.
uint32_t swap_endian(uint32_t val) {
  // Pull out each byte, lowest-order first ...
  const uint32_t byte0 = val & 0xFFu;
  const uint32_t byte1 = (val >> 8) & 0xFFu;
  const uint32_t byte2 = (val >> 16) & 0xFFu;
  const uint32_t byte3 = (val >> 24) & 0xFFu;
  // ... and reassemble them in the opposite order.
  return (byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3;
}
void convert_dataset(const char* image_filename, const char* label_filename,
const char* db_path, const int data_limit) {
// Open files
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
CAFFE_ENFORCE(image_file, "Unable to open file ", image_filename);
CAFFE_ENFORCE(label_file, "Unable to open file ", label_filename);
// Read the magic and the meta data
uint32_t magic;
uint32_t num_items;
uint32_t num_labels;
uint32_t rows;
uint32_t cols;
image_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
if (magic == 529205256) {
LOG(FATAL) <<
"It seems that you forgot to unzip the mnist dataset. You should "
"first unzip them using e.g. gunzip on Linux.";
}
CAFFE_ENFORCE_EQ(magic, 2051, "Incorrect image file magic.");
label_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CAFFE_ENFORCE_EQ(magic, 2049, "Incorrect label file magic.");
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items);
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
num_labels = swap_endian(num_labels);
CAFFE_ENFORCE_EQ(num_items, num_labels);
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows);
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols);
// leveldb
std::unique_ptr<db::DB> mnist_db(db::CreateDB(FLAGS_db, db_path, db::NEW));
std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
// Storing to db
char label_value;
std::vector<char> pixels(rows * cols);
int count = 0;
const int kMaxKeyLength = 10;
char key_cstr[kMaxKeyLength];
string value;
TensorProtos protos;
TensorProto* data = protos.add_protos();
TensorProto* label = protos.add_protos();
data->set_data_type(TensorProto::BYTE);
if (FLAGS_channel_first) {
data->add_dims(1);
data->add_dims(rows);
data->add_dims(cols);
} else {
data->add_dims(rows);
data->add_dims(cols);
data->add_dims(1);
}
label->set_data_type(TensorProto::INT32);
label->add_int32_data(0);
LOG(INFO) << "A total of " << num_items << " items.";
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
for (int item_id = 0; item_id < num_items; ++item_id) {
image_file.read(pixels.data(), rows * cols);
label_file.read(&label_value, 1);
for (int i = 0; i < rows * cols; ++i) {
data->set_byte_data(pixels.data(), rows * cols);
}
label->set_int32_data(0, static_cast<int>(label_value));
snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
protos.SerializeToString(&value);
string keystr(key_cstr);
// Put in db
transaction->Put(keystr, value);
if (++count % 1000 == 0) {
transaction->Commit();
}
if (data_limit > 0 && count == data_limit) {
LOG(INFO) << "Reached data limit of " << data_limit << ", stop.";
break;
}
}
}
} // namespace caffe2
// Entry point: initializes Caffe2 (which also parses the command-line flags
// into the FLAGS_* variables) and runs the conversion with the flag values.
int main(int argc, char** argv) {
  caffe2::GlobalInit(&argc, &argv);

  // Name each flag value before handing it over to the converter.
  const char* const image_path = FLAGS_image_file.c_str();
  const char* const label_path = FLAGS_label_file.c_str();
  const char* const output_path = FLAGS_output_file.c_str();

  caffe2::convert_dataset(
      image_path, label_path, output_path, FLAGS_data_limit);
  return 0;
}