mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Caffe2 module update: move observers as well as binaries. (#2145)
* Caffe2 module update: move observers as well as binaries. * Add threads linkage * Add Threads dependency to public interface
This commit is contained in:
52
binaries/CMakeLists.txt
Normal file
52
binaries/CMakeLists.txt
Normal file
@ -0,0 +1,52 @@
|
||||
caffe2_binary_target("convert_caffe_image_db.cc")
|
||||
caffe2_binary_target("convert_db.cc")
|
||||
caffe2_binary_target("make_cifar_db.cc")
|
||||
caffe2_binary_target("make_mnist_db.cc")
|
||||
caffe2_binary_target("predictor_verifier.cc")
|
||||
caffe2_binary_target("print_registered_core_operators.cc")
|
||||
caffe2_binary_target("run_plan.cc")
|
||||
caffe2_binary_target("speed_benchmark.cc")
|
||||
caffe2_binary_target("split_db.cc")
|
||||
|
||||
caffe2_binary_target("db_throughput.cc")
|
||||
|
||||
if (USE_CUDA)
|
||||
caffe2_binary_target("inspect_gpus.cc")
|
||||
target_link_libraries(inspect_gpus ${CUDA_LIBRARIES})
|
||||
caffe2_binary_target("print_core_object_sizes.cc")
|
||||
|
||||
if (BUILD_TEST)
|
||||
# Core overhead benchmark
|
||||
caffe2_binary_target("core_overhead_benchmark.cc")
|
||||
target_link_libraries(core_overhead_benchmark benchmark ${CUDA_curand_LIBRARY})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (USE_ZMQ)
|
||||
caffe2_binary_target("zmq_feeder.cc")
|
||||
target_link_libraries(zmq_feeder ${ZMQ_LIBRARIES})
|
||||
endif()
|
||||
|
||||
if(USE_MPI)
|
||||
caffe2_binary_target("run_plan_mpi.cc")
|
||||
target_link_libraries(run_plan_mpi ${MPI_CXX_LIBRARIES})
|
||||
endif()
|
||||
|
||||
if (USE_OPENCV AND USE_LEVELDB)
|
||||
caffe2_binary_target("convert_encoded_to_raw_leveldb.cc")
|
||||
target_link_libraries(
|
||||
convert_encoded_to_raw_leveldb
|
||||
${OpenCV_LIBS} ${LevelDB_LIBRARIES} ${Snappy_LIBRARIES})
|
||||
endif()
|
||||
|
||||
if (USE_OPENCV)
|
||||
caffe2_binary_target("make_image_db.cc")
|
||||
target_link_libraries(make_image_db ${OpenCV_LIBS})
|
||||
endif()
|
||||
|
||||
if (USE_OBSERVERS)
|
||||
caffe2_binary_target("caffe2_benchmark.cc")
|
||||
endif()
|
||||
|
||||
# ---[ tutorials
|
||||
caffe2_binary_target("tutorial_blob.cc")
|
239
binaries/caffe2_benchmark.cc
Normal file
239
binaries/caffe2_benchmark.cc
Normal file
@ -0,0 +1,239 @@
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/blob_serialization.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
|
||||
#include "observers/observer_config.h"
|
||||
|
||||
CAFFE2_DEFINE_string(
|
||||
backend,
|
||||
"builtin",
|
||||
"The backend to use when running the model. The allowed "
|
||||
"backend choices are: builtin, default, nnpack, eigen, mkl");
|
||||
CAFFE2_DEFINE_string(
|
||||
init_net,
|
||||
"",
|
||||
"The given net to initialize any parameters.");
|
||||
CAFFE2_DEFINE_string(
|
||||
input,
|
||||
"",
|
||||
"Input that is needed for running the network. If "
|
||||
"multiple input needed, use comma separated string.");
|
||||
CAFFE2_DEFINE_string(
|
||||
input_dims,
|
||||
"",
|
||||
"Alternate to input_files, if all inputs are simple "
|
||||
"float TensorCPUs, specify the dimension using comma "
|
||||
"separated numbers. If multiple input needed, use "
|
||||
"semicolon to separate the dimension of different "
|
||||
"tensors.");
|
||||
CAFFE2_DEFINE_string(
|
||||
input_file,
|
||||
"",
|
||||
"Input file that contain the serialized protobuf for "
|
||||
"the input blobs. If multiple input needed, use comma "
|
||||
"separated string. Must have the same number of items "
|
||||
"as input does.");
|
||||
CAFFE2_DEFINE_string(
|
||||
input_type,
|
||||
"float",
|
||||
"Input type when specifying the input dimension."
|
||||
"The supported types are float, uint8_t.");
|
||||
CAFFE2_DEFINE_int(iter, 10, "The number of iterations to run.");
|
||||
CAFFE2_DEFINE_string(net, "", "The given net to benchmark.");
|
||||
CAFFE2_DEFINE_string(
|
||||
output,
|
||||
"",
|
||||
"Output that should be dumped after the execution "
|
||||
"finishes. If multiple outputs are needed, use comma "
|
||||
"separated string. If you want to dump everything, pass "
|
||||
"'*' as the output value.");
|
||||
CAFFE2_DEFINE_string(
|
||||
output_folder,
|
||||
"",
|
||||
"The folder that the output should be written to. This "
|
||||
"folder must already exist in the file system.");
|
||||
CAFFE2_DEFINE_bool(
|
||||
run_individual,
|
||||
false,
|
||||
"Whether to benchmark individual operators.");
|
||||
CAFFE2_DEFINE_bool(
|
||||
text_output,
|
||||
false,
|
||||
"Whether to write out output in text format for regression purpose.");
|
||||
CAFFE2_DEFINE_int(warmup, 0, "The number of iterations to warm up.");
|
||||
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
using std::vector;
|
||||
|
||||
static void writeTextOutput(
|
||||
caffe2::TensorCPU* tensor,
|
||||
const string& output_prefix,
|
||||
const string& name) {
|
||||
string output_name = output_prefix + "/" + name + ".txt";
|
||||
caffe2::TensorSerializer<caffe2::CPUContext> ser;
|
||||
caffe2::BlobProto blob_proto;
|
||||
ser.Serialize(
|
||||
*tensor, output_name, blob_proto.mutable_tensor(), 0, tensor->size());
|
||||
blob_proto.set_name(output_name);
|
||||
blob_proto.set_type("Tensor");
|
||||
CAFFE_ENFORCE(blob_proto.has_tensor());
|
||||
caffe2::TensorProto tensor_proto = blob_proto.tensor();
|
||||
vector<float> data;
|
||||
switch (tensor_proto.data_type()) {
|
||||
case caffe2::TensorProto::FLOAT: {
|
||||
std::copy(
|
||||
tensor_proto.float_data().begin(),
|
||||
tensor_proto.float_data().end(),
|
||||
std::back_inserter(data));
|
||||
break;
|
||||
}
|
||||
case caffe2::TensorProto::INT32: {
|
||||
std::copy(
|
||||
tensor_proto.int32_data().begin(),
|
||||
tensor_proto.int32_data().end(),
|
||||
std::back_inserter(data));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
CAFFE_THROW("Unimplemented Blob type.");
|
||||
}
|
||||
std::ofstream output_file(output_name);
|
||||
std::ostream_iterator<float> output_iterator(output_file, "\n");
|
||||
std::copy(data.begin(), data.end(), output_iterator);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::ShowLogInfoToStderr();
|
||||
unique_ptr<caffe2::Workspace> workspace(new caffe2::Workspace());
|
||||
|
||||
// Run initialization network.
|
||||
caffe2::NetDef init_net_def;
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_init_net, &init_net_def));
|
||||
CAFFE_ENFORCE(workspace->RunNetOnce(init_net_def));
|
||||
|
||||
// Load input.
|
||||
if (caffe2::FLAGS_input.size()) {
|
||||
vector<string> input_names = caffe2::split(',', caffe2::FLAGS_input);
|
||||
if (caffe2::FLAGS_input_file.size()) {
|
||||
vector<string> input_files = caffe2::split(',', caffe2::FLAGS_input_file);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_files.size(),
|
||||
"Input name and file should have the same number.");
|
||||
for (int i = 0; i < input_names.size(); ++i) {
|
||||
caffe2::BlobProto blob_proto;
|
||||
CAFFE_ENFORCE(caffe2::ReadProtoFromFile(input_files[i], &blob_proto));
|
||||
workspace->CreateBlob(input_names[i])->Deserialize(blob_proto);
|
||||
}
|
||||
} else if (caffe2::FLAGS_input_dims.size()) {
|
||||
vector<string> input_dims_list =
|
||||
caffe2::split(';', caffe2::FLAGS_input_dims);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_dims_list.size(),
|
||||
"Input name and dims should have the same number of items.");
|
||||
for (int i = 0; i < input_names.size(); ++i) {
|
||||
vector<string> input_dims_str = caffe2::split(',', input_dims_list[i]);
|
||||
vector<int> input_dims;
|
||||
for (const string& s : input_dims_str) {
|
||||
input_dims.push_back(caffe2::stoi(s));
|
||||
}
|
||||
if (!workspace->HasBlob(input_names[i])) {
|
||||
workspace->CreateBlob(input_names[i]);
|
||||
}
|
||||
caffe2::TensorCPU* tensor =
|
||||
workspace->GetBlob(input_names[i])->GetMutable<caffe2::TensorCPU>();
|
||||
tensor->Resize(input_dims);
|
||||
if (caffe2::FLAGS_input_type == "float") {
|
||||
tensor->mutable_data<float>();
|
||||
} else {
|
||||
CAFFE_ENFORCE(
|
||||
caffe2::FLAGS_input_type == "uint8_t",
|
||||
"Only supported input types are: float, uint8_t");
|
||||
tensor->mutable_data<uint8_t>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
CAFFE_THROW(
|
||||
"You requested input tensors, but neither input_file nor "
|
||||
"input_dims is set.");
|
||||
}
|
||||
}
|
||||
|
||||
// Run main network.
|
||||
caffe2::NetDef net_def;
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def));
|
||||
if (caffe2::FLAGS_backend != "builtin") {
|
||||
std::string engine = caffe2::FLAGS_backend == "nnpack" ? "NNPACK" :
|
||||
caffe2::FLAGS_backend == "eigen" ? "EIGEN" :
|
||||
caffe2::FLAGS_backend == "mkl" ? "MKLDNN" :
|
||||
caffe2::FLAGS_backend == "default" ? "" : "NONE";
|
||||
CAFFE_ENFORCE(engine != "NONE", "Backend is not supported");
|
||||
for (int i = 0; i < net_def.op_size(); i++) {
|
||||
caffe2::OperatorDef* op_def = net_def.mutable_op(i);
|
||||
op_def->set_engine(engine);
|
||||
}
|
||||
}
|
||||
|
||||
caffe2::NetBase* net = workspace->CreateNet(net_def);
|
||||
CHECK_NOTNULL(net);
|
||||
|
||||
LOG(INFO) << "Starting benchmark.";
|
||||
caffe2::ObserverConfig::initSampleRate(
|
||||
1, 1, 1, caffe2::FLAGS_run_individual, caffe2::FLAGS_warmup);
|
||||
LOG(INFO) << "Running warmup runs.";
|
||||
for (int i = 0; i < caffe2::FLAGS_warmup; ++i) {
|
||||
CAFFE_ENFORCE(net->Run(), "Warmup run ", i, " has failed.");
|
||||
}
|
||||
|
||||
LOG(INFO) << "Main runs.";
|
||||
CAFFE_ENFORCE(
|
||||
caffe2::FLAGS_iter >= 0,
|
||||
"Number of main runs should be non negative, provided ",
|
||||
caffe2::FLAGS_iter,
|
||||
".");
|
||||
for (int i = 0; i < caffe2::FLAGS_iter; ++i) {
|
||||
caffe2::ObserverConfig::initSampleRate(1, 1, 1, 0, caffe2::FLAGS_warmup);
|
||||
CAFFE_ENFORCE(net->Run(), "Main run ", i, " has failed.");
|
||||
if (caffe2::FLAGS_run_individual) {
|
||||
caffe2::ObserverConfig::initSampleRate(1, 1, 1, 1, caffe2::FLAGS_warmup);
|
||||
CAFFE_ENFORCE(net->Run(), "Main run ", i, " with operator has failed.");
|
||||
}
|
||||
}
|
||||
|
||||
string output_prefix = caffe2::FLAGS_output_folder.size()
|
||||
? caffe2::FLAGS_output_folder + "/"
|
||||
: "";
|
||||
if (caffe2::FLAGS_output.size()) {
|
||||
vector<string> output_names = caffe2::split(',', caffe2::FLAGS_output);
|
||||
if (caffe2::FLAGS_output == "*") {
|
||||
output_names = workspace->Blobs();
|
||||
}
|
||||
for (const string& name : output_names) {
|
||||
CAFFE_ENFORCE(
|
||||
workspace->HasBlob(name),
|
||||
"You requested a non-existing blob: ",
|
||||
name);
|
||||
if (caffe2::FLAGS_text_output) {
|
||||
auto blob = workspace->GetBlob(name)->GetMutable<caffe2::TensorCPU>();
|
||||
writeTextOutput(blob, output_prefix, name);
|
||||
} else {
|
||||
string serialized = workspace->GetBlob(name)->Serialize(name);
|
||||
string output_filename = output_prefix + name;
|
||||
caffe2::WriteStringToFile(serialized, output_filename.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
90
binaries/convert_caffe_image_db.cc
Normal file
90
binaries/convert_caffe_image_db.cc
Normal file
@ -0,0 +1,90 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe/proto/caffe.pb.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_string(input_db, "", "The input db.");
|
||||
CAFFE2_DEFINE_string(input_db_type, "", "The input db type.");
|
||||
CAFFE2_DEFINE_string(output_db, "", "The output db.");
|
||||
CAFFE2_DEFINE_string(output_db_type, "", "The output db type.");
|
||||
CAFFE2_DEFINE_int(batch_size, 1000, "The write batch size.");
|
||||
|
||||
using caffe2::db::Cursor;
|
||||
using caffe2::db::DB;
|
||||
using caffe2::db::Transaction;
|
||||
using caffe2::TensorProto;
|
||||
using caffe2::TensorProtos;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
|
||||
std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
|
||||
caffe2::FLAGS_input_db_type, caffe2::FLAGS_input_db, caffe2::db::READ));
|
||||
std::unique_ptr<DB> out_db(caffe2::db::CreateDB(
|
||||
caffe2::FLAGS_output_db_type, caffe2::FLAGS_output_db, caffe2::db::NEW));
|
||||
std::unique_ptr<Cursor> cursor(in_db->NewCursor());
|
||||
std::unique_ptr<Transaction> transaction(out_db->NewTransaction());
|
||||
int count = 0;
|
||||
for (; cursor->Valid(); cursor->Next()) {
|
||||
caffe::Datum datum;
|
||||
CAFFE_ENFORCE(datum.ParseFromString(cursor->value()));
|
||||
TensorProtos protos;
|
||||
TensorProto* data = protos.add_protos();
|
||||
TensorProto* label = protos.add_protos();
|
||||
label->set_data_type(TensorProto::INT32);
|
||||
label->add_dims(1);
|
||||
label->add_int32_data(datum.label());
|
||||
if (datum.encoded()) {
|
||||
// This is an encoded image. we will copy over the data directly.
|
||||
data->set_data_type(TensorProto::STRING);
|
||||
data->add_dims(1);
|
||||
data->add_string_data(datum.data());
|
||||
} else {
|
||||
// float data not supported right now.
|
||||
CAFFE_ENFORCE_EQ(datum.float_data_size(), 0);
|
||||
std::vector<char> buffer_vec(datum.data().size());
|
||||
char* buffer = buffer_vec.data();
|
||||
// swap order from CHW to HWC
|
||||
int channels = datum.channels();
|
||||
int size = datum.height() * datum.width();
|
||||
CAFFE_ENFORCE_EQ(datum.data().size(), channels * size);
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
char* dst = buffer + c;
|
||||
const char* src = datum.data().c_str() + c * size;
|
||||
for (int n = 0; n < size; ++n) {
|
||||
dst[n*channels] = src[n];
|
||||
}
|
||||
}
|
||||
data->set_data_type(TensorProto::BYTE);
|
||||
data->add_dims(datum.height());
|
||||
data->add_dims(datum.width());
|
||||
data->add_dims(datum.channels());
|
||||
data->set_byte_data(buffer, datum.data().size());
|
||||
}
|
||||
transaction->Put(cursor->key(), protos.SerializeAsString());
|
||||
if (++count % caffe2::FLAGS_batch_size == 0) {
|
||||
transaction->Commit();
|
||||
LOG(INFO) << "Converted " << count << " items so far.";
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "A total of " << count << " items processed.";
|
||||
return 0;
|
||||
}
|
||||
|
51
binaries/convert_db.cc
Normal file
51
binaries/convert_db.cc
Normal file
@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_string(input_db, "", "The input db.");
|
||||
CAFFE2_DEFINE_string(input_db_type, "", "The input db type.");
|
||||
CAFFE2_DEFINE_string(output_db, "", "The output db.");
|
||||
CAFFE2_DEFINE_string(output_db_type, "", "The output db type.");
|
||||
CAFFE2_DEFINE_int(batch_size, 1000, "The write batch size.");
|
||||
|
||||
using caffe2::db::Cursor;
|
||||
using caffe2::db::DB;
|
||||
using caffe2::db::Transaction;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
|
||||
std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
|
||||
caffe2::FLAGS_input_db_type, caffe2::FLAGS_input_db, caffe2::db::READ));
|
||||
std::unique_ptr<DB> out_db(caffe2::db::CreateDB(
|
||||
caffe2::FLAGS_output_db_type, caffe2::FLAGS_output_db, caffe2::db::NEW));
|
||||
std::unique_ptr<Cursor> cursor(in_db->NewCursor());
|
||||
std::unique_ptr<Transaction> transaction(out_db->NewTransaction());
|
||||
int count = 0;
|
||||
for (; cursor->Valid(); cursor->Next()) {
|
||||
transaction->Put(cursor->key(), cursor->value());
|
||||
if (++count % caffe2::FLAGS_batch_size == 0) {
|
||||
transaction->Commit();
|
||||
LOG(INFO) << "Converted " << count << " items so far.";
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "A total of " << count << " items processed.";
|
||||
return 0;
|
||||
}
|
156
binaries/convert_encoded_to_raw_leveldb.cc
Normal file
156
binaries/convert_encoded_to_raw_leveldb.cc
Normal file
@ -0,0 +1,156 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// This script converts an image dataset to leveldb.
|
||||
//
|
||||
// caffe2::FLAGS_input_folder is the root folder that holds all the images, and
|
||||
// caffe2::FLAGS_list_file should be a list of files as well as their labels, in the
|
||||
// format as
|
||||
// subfolder1/file1.JPEG 7
|
||||
// ....
|
||||
|
||||
#include <opencv2/opencv.hpp>
|
||||
|
||||
#include <fstream> // NOLINT(readability/streams)
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "leveldb/db.h"
|
||||
#include "leveldb/write_batch.h"
|
||||
|
||||
CAFFE2_DEFINE_string(input_db_name, "", "The input image file name.");
|
||||
CAFFE2_DEFINE_string(output_db_name, "", "The output training leveldb name.");
|
||||
CAFFE2_DEFINE_bool(color, true, "If set, load images in color.");
|
||||
CAFFE2_DEFINE_int(scale, 256,
|
||||
"If caffe2::FLAGS_raw is set, scale all the images' shorter edge to the given "
|
||||
"value.");
|
||||
CAFFE2_DEFINE_bool(warp, false, "If warp is set, warp the images to square.");
|
||||
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
|
||||
void ConvertToRawDataset(
|
||||
const string& input_db_name, const string& output_db_name) {
|
||||
// input leveldb
|
||||
std::unique_ptr<leveldb::DB> input_db;
|
||||
LOG(INFO) << "Opening input leveldb " << input_db_name;
|
||||
{
|
||||
leveldb::Options options;
|
||||
options.create_if_missing = false;
|
||||
leveldb::DB* db_temp;
|
||||
leveldb::Status status = leveldb::DB::Open(
|
||||
options, input_db_name, &db_temp);
|
||||
CAFFE_ENFORCE(status.ok(), "Failed to open leveldb ", input_db_name, ".");
|
||||
input_db.reset(db_temp);
|
||||
}
|
||||
|
||||
// output leveldb
|
||||
std::unique_ptr<leveldb::DB> output_db;
|
||||
std::unique_ptr<leveldb::WriteBatch> batch;
|
||||
LOG(INFO) << "Opening leveldb " << output_db_name;
|
||||
{
|
||||
leveldb::Options options;
|
||||
options.error_if_exists = true;
|
||||
options.create_if_missing = true;
|
||||
options.write_buffer_size = 268435456;
|
||||
leveldb::DB* db_temp;
|
||||
leveldb::Status status = leveldb::DB::Open(
|
||||
options, output_db_name, &db_temp);
|
||||
CAFFE_ENFORCE(
|
||||
status.ok(),
|
||||
"Failed to open leveldb ",
|
||||
output_db_name,
|
||||
". Is it already existing?");
|
||||
output_db.reset(db_temp);
|
||||
}
|
||||
batch.reset(new leveldb::WriteBatch());
|
||||
|
||||
TensorProtos input_protos;
|
||||
TensorProtos output_protos;
|
||||
TensorProto* data = output_protos.add_protos();
|
||||
TensorProto* label = output_protos.add_protos();
|
||||
data->set_data_type(TensorProto::BYTE);
|
||||
data->add_dims(0);
|
||||
data->add_dims(0);
|
||||
if (caffe2::FLAGS_color) {
|
||||
data->add_dims(3);
|
||||
}
|
||||
string value;
|
||||
|
||||
unique_ptr<leveldb::Iterator> iter;
|
||||
iter.reset(input_db->NewIterator(leveldb::ReadOptions()));
|
||||
iter->SeekToFirst();
|
||||
int count = 0;
|
||||
for (; iter->Valid(); iter->Next()) {
|
||||
CAFFE_ENFORCE(input_protos.ParseFromString(iter->value().ToString()));
|
||||
label->CopyFrom(input_protos.protos(1));
|
||||
const string& encoded_image = input_protos.protos(0).string_data(0);
|
||||
int encoded_size = encoded_image.size();
|
||||
cv::Mat img = cv::imdecode(
|
||||
cv::Mat(1, &encoded_size, CV_8UC1,
|
||||
const_cast<char*>(encoded_image.data())),
|
||||
caffe2::FLAGS_color ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE);
|
||||
cv::Mat resized_img;
|
||||
int scaled_width, scaled_height;
|
||||
if (caffe2::FLAGS_warp) {
|
||||
scaled_width = caffe2::FLAGS_scale;
|
||||
scaled_height = caffe2::FLAGS_scale;
|
||||
} else if (img.rows > img.cols) {
|
||||
scaled_width = caffe2::FLAGS_scale;
|
||||
scaled_height = static_cast<float>(img.rows) * caffe2::FLAGS_scale / img.cols;
|
||||
} else {
|
||||
scaled_height = caffe2::FLAGS_scale;
|
||||
scaled_width = static_cast<float>(img.cols) * caffe2::FLAGS_scale / img.rows;
|
||||
}
|
||||
cv::resize(img, resized_img, cv::Size(scaled_width, scaled_height), 0, 0,
|
||||
cv::INTER_LINEAR);
|
||||
data->set_dims(0, scaled_height);
|
||||
data->set_dims(1, scaled_width);
|
||||
DCHECK(resized_img.isContinuous());
|
||||
data->set_byte_data(resized_img.ptr(),
|
||||
scaled_height * scaled_width * (caffe2::FLAGS_color ? 3 : 1));
|
||||
output_protos.SerializeToString(&value);
|
||||
// Put in db
|
||||
batch->Put(iter->key(), value);
|
||||
if (++count % 1000 == 0) {
|
||||
output_db->Write(leveldb::WriteOptions(), batch.get());
|
||||
batch.reset(new leveldb::WriteBatch());
|
||||
LOG(INFO) << "Processed " << count << " files.";
|
||||
}
|
||||
}
|
||||
// write the last batch
|
||||
if (count % 1000 != 0) {
|
||||
output_db->Write(leveldb::WriteOptions(), batch.get());
|
||||
}
|
||||
LOG(INFO) << "Processed a total of " << count << " files.";
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::ConvertToRawDataset(
|
||||
caffe2::FLAGS_input_db_name, caffe2::FLAGS_output_db_name);
|
||||
return 0;
|
||||
}
|
223
binaries/core_overhead_benchmark.cc
Normal file
223
binaries/core_overhead_benchmark.cc
Normal file
@ -0,0 +1,223 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "benchmark/benchmark.h"
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
#define CAFFE2_SKIP_IF_NO_GPU \
|
||||
if (!caffe2::NumCudaDevices()) { \
|
||||
state.SkipWithError("No CUDA available, skipping benchmark."); \
|
||||
return; \
|
||||
}
|
||||
|
||||
using namespace caffe2;
|
||||
|
||||
static void BM_CUDAContextCreation(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
volatile CUDAContext context_so_we_do_initialization_work;
|
||||
while (state.KeepRunning()) {
|
||||
volatile CUDAContext context;
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_CUDAContextCreation);
|
||||
|
||||
static void BM_CUDAContextStreamAccess(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
CUDAContext context;
|
||||
while (state.KeepRunning()) {
|
||||
volatile cudaStream_t stream = context.cuda_stream();
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_CUDAContextStreamAccess);
|
||||
|
||||
static void BM_cudaGetDevice(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
int id;
|
||||
while (state.KeepRunning()) {
|
||||
CUDA_ENFORCE(cudaGetDevice(&id));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_cudaGetDevice);
|
||||
|
||||
static void BM_cudaSetDevice(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
int total = NumCudaDevices();
|
||||
int i = 0;
|
||||
while (state.KeepRunning()) {
|
||||
CUDA_ENFORCE(cudaSetDevice((i++) % total));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_cudaSetDevice);
|
||||
|
||||
static void BM_cudaSetAndGetDevice(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
int total = NumCudaDevices();
|
||||
int i = 0;
|
||||
int id;
|
||||
while (state.KeepRunning()) {
|
||||
CUDA_ENFORCE(cudaSetDevice((i++) % total));
|
||||
CUDA_ENFORCE(cudaGetDevice(&id));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_cudaSetAndGetDevice);
|
||||
|
||||
static void BM_cudaSetSameDevice(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
while (state.KeepRunning()) {
|
||||
CUDA_ENFORCE(cudaSetDevice(0));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_cudaSetSameDevice);
|
||||
|
||||
static void BM_cudaStreamCreateSyncDelete(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
cudaStream_t stream;
|
||||
while (state.KeepRunning()) {
|
||||
CUDA_ENFORCE(cudaStreamCreate(&stream));
|
||||
CUDA_ENFORCE(cudaStreamSynchronize(stream));
|
||||
CUDA_ENFORCE(cudaStreamDestroy(stream));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_cudaStreamCreateSyncDelete);
|
||||
|
||||
static void BM_cudaStreamSynchronize(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
cudaStream_t stream;
|
||||
CUDA_ENFORCE(cudaStreamCreate(&stream));
|
||||
while (state.KeepRunning()) {
|
||||
CUDA_ENFORCE(cudaStreamSynchronize(stream));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_cudaStreamSynchronize);
|
||||
|
||||
static void BM_cudaEventRecord(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
cudaStream_t stream;
|
||||
cudaEvent_t event;
|
||||
CUDA_ENFORCE(cudaStreamCreate(&stream));
|
||||
CUDA_ENFORCE(cudaEventCreateWithFlags(
|
||||
&event, cudaEventDefault | cudaEventDisableTiming));
|
||||
while (state.KeepRunning()) {
|
||||
CUDA_ENFORCE(cudaEventRecord(event, stream));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_cudaEventRecord);
|
||||
|
||||
static void BM_cudaStreamWaitEventThenStreamSynchronize(
|
||||
benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
cudaStream_t stream;
|
||||
cudaEvent_t event;
|
||||
CUDA_ENFORCE(cudaStreamCreate(&stream));
|
||||
CUDA_ENFORCE(cudaEventCreateWithFlags(
|
||||
&event, cudaEventDefault | cudaEventDisableTiming));
|
||||
CUDA_ENFORCE(cudaEventRecord(event, stream));
|
||||
CUDA_ENFORCE(cudaStreamWaitEvent(stream, event, 0));
|
||||
CUDA_ENFORCE(cudaStreamSynchronize(stream));
|
||||
while (state.KeepRunning()) {
|
||||
CUDA_ENFORCE(cudaStreamWaitEvent(stream, event, 0));
|
||||
CUDA_ENFORCE(cudaStreamSynchronize(stream));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_cudaStreamWaitEventThenStreamSynchronize);
|
||||
|
||||
static void BM_CudaPointerAffinity(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
TensorCUDA tensor(vector<TIndex>{1, 2, 3, 4});
|
||||
float* ptr = tensor.mutable_data<float>();
|
||||
while (state.KeepRunning()) {
|
||||
volatile int id = GetGPUIDForPointer(ptr);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_CudaPointerAffinity);
|
||||
|
||||
namespace {
|
||||
template <class Context>
|
||||
class DummyEmptyOp : public Operator<Context> {
|
||||
public:
|
||||
DummyEmptyOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws) {}
|
||||
|
||||
bool RunOnDevice() final { return true; }
|
||||
};
|
||||
|
||||
REGISTER_CPU_OPERATOR(DummyEmpty, DummyEmptyOp<CPUContext>);
|
||||
REGISTER_CUDA_OPERATOR(DummyEmpty, DummyEmptyOp<CUDAContext>);
|
||||
OPERATOR_SCHEMA(DummyEmpty);
|
||||
} // namespace
|
||||
|
||||
static void BM_OperatorCreationCPU(benchmark::State& state) {
|
||||
std::unique_ptr<OperatorBase> op;
|
||||
OperatorDef def;
|
||||
Workspace ws;
|
||||
def.set_type("DummyEmpty");
|
||||
def.mutable_device_option()->set_device_type(CPU);
|
||||
while (state.KeepRunning()) {
|
||||
op = CreateOperator(def, &ws);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_OperatorCreationCPU);
|
||||
|
||||
static void BM_OperatorCreationCUDA(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
std::unique_ptr<OperatorBase> op;
|
||||
OperatorDef def;
|
||||
Workspace ws;
|
||||
def.set_type("DummyEmpty");
|
||||
def.mutable_device_option()->set_device_type(CUDA);
|
||||
while (state.KeepRunning()) {
|
||||
op = CreateOperator(def, &ws);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_OperatorCreationCUDA);
|
||||
|
||||
static void BM_RawAllocDeallocCPU(benchmark::State& state) {
|
||||
while (state.KeepRunning()) {
|
||||
// Allocating only 1 byte in order to measure the overhead.
|
||||
auto ptr_and_deleter = GetCPUAllocator()->New(1);
|
||||
// Deallocate.
|
||||
ptr_and_deleter.second(ptr_and_deleter.first);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_RawAllocDeallocCPU);
|
||||
|
||||
static void BM_TensorAllocDeallocCPU(benchmark::State& state) {
|
||||
Tensor<CPUContext> tensor;
|
||||
// small allocation
|
||||
tensor.Resize(32, 32);
|
||||
while (state.KeepRunning()) {
|
||||
CHECK(tensor.mutable_data<float>());
|
||||
tensor.FreeMemory();
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_TensorAllocDeallocCPU);
|
||||
|
||||
static void BM_TensorAllocDeallocCUDA(benchmark::State& state) {
|
||||
CAFFE2_SKIP_IF_NO_GPU;
|
||||
Tensor<CUDAContext> tensor;
|
||||
// small allocation
|
||||
tensor.Resize(32, 32);
|
||||
while (state.KeepRunning()) {
|
||||
CHECK(tensor.mutable_data<float>());
|
||||
tensor.FreeMemory();
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_TensorAllocDeallocCUDA);
|
||||
|
||||
BENCHMARK_MAIN()
|
98
binaries/db_throughput.cc
Normal file
98
binaries/db_throughput.cc
Normal file
@ -0,0 +1,98 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cstdio>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/timer.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_string(input_db, "", "The input db.");
|
||||
CAFFE2_DEFINE_string(input_db_type, "", "The input db type.");
|
||||
CAFFE2_DEFINE_int(report_interval, 1000, "The report interval.");
|
||||
CAFFE2_DEFINE_int(repeat, 10, "The number to repeat the throughput test.");
|
||||
CAFFE2_DEFINE_bool(use_reader, false, "If true, use the reader interface.");
|
||||
CAFFE2_DEFINE_int(num_read_threads, 1,
|
||||
"The number of concurrent reading threads.");
|
||||
|
||||
using caffe2::db::Cursor;
|
||||
using caffe2::db::DB;
|
||||
using caffe2::db::DBReader;
|
||||
using caffe2::string;
|
||||
|
||||
void TestThroughputWithDB() {
|
||||
std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
|
||||
caffe2::FLAGS_input_db_type, caffe2::FLAGS_input_db, caffe2::db::READ));
|
||||
std::unique_ptr<Cursor> cursor(in_db->NewCursor());
|
||||
for (int iter_id = 0; iter_id < caffe2::FLAGS_repeat; ++iter_id) {
|
||||
caffe2::Timer timer;
|
||||
for (int i = 0; i < caffe2::FLAGS_report_interval; ++i) {
|
||||
string key = cursor->key();
|
||||
string value = cursor->value();
|
||||
//VLOG(1) << "Key " << key;
|
||||
cursor->Next();
|
||||
if (!cursor->Valid()) {
|
||||
cursor->SeekToFirst();
|
||||
}
|
||||
}
|
||||
double elapsed_seconds = timer.Seconds();
|
||||
printf("Iteration %03d, took %4.5f seconds, throughput %f items/sec.\n",
|
||||
iter_id, elapsed_seconds,
|
||||
caffe2::FLAGS_report_interval / elapsed_seconds);
|
||||
}
|
||||
}
|
||||
|
||||
void TestThroughputWithReaderWorker(const DBReader* reader, int thread_id) {
|
||||
string key, value;
|
||||
for (int iter_id = 0; iter_id < caffe2::FLAGS_repeat; ++iter_id) {
|
||||
caffe2::Timer timer;
|
||||
for (int i = 0; i < caffe2::FLAGS_report_interval; ++i) {
|
||||
reader->Read(&key, &value);
|
||||
}
|
||||
double elapsed_seconds = timer.Seconds();
|
||||
printf("Thread %03d iteration %03d, took %4.5f seconds, "
|
||||
"throughput %f items/sec.\n",
|
||||
thread_id, iter_id, elapsed_seconds,
|
||||
caffe2::FLAGS_report_interval / elapsed_seconds);
|
||||
}
|
||||
}
|
||||
|
||||
void TestThroughputWithReader() {
|
||||
caffe2::db::DBReader reader(
|
||||
caffe2::FLAGS_input_db_type, caffe2::FLAGS_input_db);
|
||||
std::vector<std::unique_ptr<std::thread>> reading_threads(
|
||||
caffe2::FLAGS_num_read_threads);
|
||||
for (int i = 0; i < reading_threads.size(); ++i) {
|
||||
reading_threads[i].reset(new std::thread(
|
||||
TestThroughputWithReaderWorker, &reader, i));
|
||||
}
|
||||
for (int i = 0; i < reading_threads.size(); ++i) {
|
||||
reading_threads[i]->join();
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
if (caffe2::FLAGS_use_reader) {
|
||||
TestThroughputWithReader();
|
||||
} else {
|
||||
TestThroughputWithDB();
|
||||
}
|
||||
return 0;
|
||||
}
|
57
binaries/inspect_gpus.cc
Normal file
57
binaries/inspect_gpus.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "caffe2/core/common_gpu.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
using std::vector;
|
||||
|
||||
CAFFE2_DECLARE_int(caffe2_log_level);
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::SetUsageMessage(
|
||||
"Inspects the GPUs on the current machine and prints out their details "
|
||||
"provided by cuda.");
|
||||
|
||||
int gpu_count;
|
||||
CUDA_ENFORCE(cudaGetDeviceCount(&gpu_count));
|
||||
for (int i = 0; i < gpu_count; ++i) {
|
||||
LOG(INFO) << "Querying device ID = " << i;
|
||||
caffe2::DeviceQuery(i);
|
||||
}
|
||||
|
||||
vector<vector<bool> > access_pattern;
|
||||
CAFFE_ENFORCE(caffe2::GetCudaPeerAccessPattern(&access_pattern));
|
||||
|
||||
std::stringstream sstream;
|
||||
// Find topology
|
||||
for (int i = 0; i < gpu_count; ++i) {
|
||||
for (int j = 0; j < gpu_count; ++j) {
|
||||
sstream << (access_pattern[i][j] ? "+" : "-") << " ";
|
||||
}
|
||||
sstream << std::endl;
|
||||
}
|
||||
LOG(INFO) << "Access pattern: " << std::endl << sstream.str();
|
||||
|
||||
return 0;
|
||||
}
|
148
binaries/make_cifar_db.cc
Normal file
148
binaries/make_cifar_db.cc
Normal file
@ -0,0 +1,148 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
//
|
||||
// This script converts the CIFAR dataset to the leveldb format used
|
||||
// by caffe to perform classification.
|
||||
// Usage:
|
||||
// convert_cifar_data input_folder output_db_file
|
||||
// The CIFAR dataset could be downloaded at
|
||||
// http://www.cs.toronto.edu/~kriz/cifar.html
|
||||
|
||||
#include <array>
|
||||
#include <fstream> // NOLINT(readability/streams)
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_string(input_folder, "", "The input folder name.");
|
||||
CAFFE2_DEFINE_string(output_train_db_name,
|
||||
"", "The output training db name.");
|
||||
CAFFE2_DEFINE_string(output_test_db_name,
|
||||
"", "The output testing db name.");
|
||||
CAFFE2_DEFINE_string(db, "leveldb", "The db type.");
|
||||
CAFFE2_DEFINE_bool(is_cifar100, false,
|
||||
"If set, convert cifar100. Otherwise do cifar10.");
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
using std::stringstream;
|
||||
|
||||
const int kCIFARSize = 32;
|
||||
const int kCIFARImageNBytes = kCIFARSize * kCIFARSize * 3;
|
||||
const int kCIFAR10BatchSize = 10000;
|
||||
const int kCIFAR10TestDataSize = 10000;
|
||||
const int kCIFAR10TrainBatches = 5;
|
||||
|
||||
const int kCIFAR100TrainDataSize = 50000;
|
||||
const int kCIFAR100TestDataSize = 10000;
|
||||
|
||||
void ReadImage(std::ifstream* file, int* label, char* buffer) {
|
||||
char label_char;
|
||||
if (caffe2::FLAGS_is_cifar100) {
|
||||
// Skip the coarse label.
|
||||
file->read(&label_char, 1);
|
||||
}
|
||||
file->read(&label_char, 1);
|
||||
*label = label_char;
|
||||
// Yes, there are better ways to do it, like in-place swap... but I am too
|
||||
// lazy so let's just write it in a memory-wasteful way.
|
||||
std::array<char, kCIFARImageNBytes> channel_first_storage;
|
||||
file->read(channel_first_storage.data(), kCIFARImageNBytes);
|
||||
for (int c = 0; c < 3; ++c) {
|
||||
for (int i = 0; i < kCIFARSize * kCIFARSize; ++i) {
|
||||
buffer[i * 3 + c] =
|
||||
channel_first_storage[c * kCIFARSize * kCIFARSize + i];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void WriteToDB(const string& filename, const int num_items,
|
||||
const int& offset, db::DB* db) {
|
||||
TensorProtos protos;
|
||||
TensorProto* data = protos.add_protos();
|
||||
TensorProto* label = protos.add_protos();
|
||||
data->set_data_type(TensorProto::BYTE);
|
||||
data->add_dims(kCIFARSize);
|
||||
data->add_dims(kCIFARSize);
|
||||
data->add_dims(3);
|
||||
label->set_data_type(TensorProto::INT32);
|
||||
label->add_dims(1);
|
||||
label->add_int32_data(0);
|
||||
|
||||
LOG(INFO) << "Converting file " << filename;
|
||||
std::ifstream data_file(filename.c_str(),
|
||||
std::ios::in | std::ios::binary);
|
||||
CAFFE_ENFORCE(data_file, "Unable to open file ", filename);
|
||||
char str_buffer[kCIFARImageNBytes];
|
||||
int label_value;
|
||||
string serialized_protos;
|
||||
std::unique_ptr<db::Transaction> transaction(db->NewTransaction());
|
||||
for (int itemid = 0; itemid < num_items; ++itemid) {
|
||||
ReadImage(&data_file, &label_value, str_buffer);
|
||||
data->set_byte_data(str_buffer, kCIFARImageNBytes);
|
||||
label->set_int32_data(0, label_value);
|
||||
protos.SerializeToString(&serialized_protos);
|
||||
snprintf(str_buffer, kCIFARImageNBytes, "%05d",
|
||||
offset + itemid);
|
||||
transaction->Put(string(str_buffer), serialized_protos);
|
||||
}
|
||||
}
|
||||
|
||||
void ConvertCIFAR() {
|
||||
std::unique_ptr<db::DB> train_db(
|
||||
db::CreateDB(caffe2::FLAGS_db, caffe2::FLAGS_output_train_db_name,
|
||||
db::NEW));
|
||||
std::unique_ptr<db::DB> test_db(
|
||||
db::CreateDB(caffe2::FLAGS_db, caffe2::FLAGS_output_test_db_name,
|
||||
db::NEW));
|
||||
|
||||
if (!caffe2::FLAGS_is_cifar100) {
|
||||
// This is cifar 10.
|
||||
for (int fileid = 0; fileid < kCIFAR10TrainBatches; ++fileid) {
|
||||
stringstream train_file;
|
||||
train_file << caffe2::FLAGS_input_folder << "/data_batch_" << fileid + 1
|
||||
<< ".bin";
|
||||
WriteToDB(train_file.str(), kCIFAR10BatchSize,
|
||||
fileid * kCIFAR10BatchSize, train_db.get());
|
||||
}
|
||||
stringstream test_file;
|
||||
test_file << caffe2::FLAGS_input_folder << "/test_batch.bin";
|
||||
WriteToDB(test_file.str(), kCIFAR10TestDataSize, 0, test_db.get());
|
||||
} else {
|
||||
// This is cifar 100.
|
||||
stringstream train_file;
|
||||
train_file << caffe2::FLAGS_input_folder << "/train.bin";
|
||||
WriteToDB(train_file.str(), kCIFAR100TrainDataSize, 0, train_db.get());
|
||||
stringstream test_file;
|
||||
test_file << caffe2::FLAGS_input_folder << "/test.bin";
|
||||
WriteToDB(test_file.str(), kCIFAR100TestDataSize, 0, test_db.get());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::ConvertCIFAR();
|
||||
return 0;
|
||||
}
|
280
binaries/make_image_db.cc
Normal file
280
binaries/make_image_db.cc
Normal file
@ -0,0 +1,280 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// This script converts an image dataset to a database.
|
||||
//
|
||||
// caffe2::FLAGS_input_folder is the root folder that holds all the images
|
||||
//
|
||||
// caffe2::FLAGS_list_file is the path to a file containing a list of files
|
||||
// and their labels, as follows:
|
||||
//
|
||||
// subfolder1/file1.JPEG 7
|
||||
// subfolder1/file2.JPEG 7
|
||||
// subfolder2/file1.JPEG 8
|
||||
// ...
|
||||
//
|
||||
|
||||
#include <opencv2/opencv.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_bool(shuffle, false,
|
||||
"Randomly shuffle the order of images and their labels");
|
||||
CAFFE2_DEFINE_string(input_folder, "", "The input image file name.");
|
||||
CAFFE2_DEFINE_string(
|
||||
list_file,
|
||||
"",
|
||||
"The text file containing the list of images.");
|
||||
CAFFE2_DEFINE_string(output_db_name, "", "The output training leveldb name.");
|
||||
CAFFE2_DEFINE_string(db, "leveldb", "The db type.");
|
||||
CAFFE2_DEFINE_bool(raw, false,
|
||||
"If set, we pre-read the images and store the raw buffer.");
|
||||
CAFFE2_DEFINE_bool(color, true, "If set, load images in color.");
|
||||
CAFFE2_DEFINE_int(
|
||||
scale,
|
||||
256,
|
||||
"If caffe2::FLAGS_raw is set, scale the shorter edge to the given value.");
|
||||
CAFFE2_DEFINE_bool(warp, false, "If warp is set, warp the images to square.");
|
||||
CAFFE2_DEFINE_int(
|
||||
num_threads,
|
||||
-1,
|
||||
"Number of image parsing and conversion threads.");
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
class Converter {
|
||||
public:
|
||||
explicit Converter() {
|
||||
data_ = protos_.add_protos();
|
||||
label_ = protos_.add_protos();
|
||||
if (caffe2::FLAGS_raw) {
|
||||
data_->set_data_type(TensorProto::BYTE);
|
||||
data_->add_dims(0);
|
||||
data_->add_dims(0);
|
||||
if (caffe2::FLAGS_color) {
|
||||
data_->add_dims(3);
|
||||
}
|
||||
} else {
|
||||
data_->set_data_type(TensorProto::STRING);
|
||||
data_->add_dims(1);
|
||||
data_->add_string_data("");
|
||||
}
|
||||
label_->set_data_type(TensorProto::INT32);
|
||||
label_->add_dims(1);
|
||||
label_->add_int32_data(0);
|
||||
}
|
||||
|
||||
~Converter() {
|
||||
if (thread_.joinable()) {
|
||||
thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
void queue(const std::pair<std::string, int>& pair) {
|
||||
in_.push(pair);
|
||||
}
|
||||
|
||||
void start() {
|
||||
thread_ = std::thread(&Converter::run, this);
|
||||
}
|
||||
|
||||
std::string get() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (out_.empty()) {
|
||||
cv_.wait(lock);
|
||||
}
|
||||
|
||||
auto value = out_.front();
|
||||
out_.pop();
|
||||
cv_.notify_one();
|
||||
return value;
|
||||
}
|
||||
|
||||
void run() {
|
||||
const auto& input_folder = caffe2::FLAGS_input_folder;
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
std::string value;
|
||||
while (!in_.empty()) {
|
||||
auto pair = in_.front();
|
||||
in_.pop();
|
||||
lock.unlock();
|
||||
|
||||
label_->set_int32_data(0, pair.second);
|
||||
|
||||
// Add raw file contents to DB if !raw
|
||||
if (!caffe2::FLAGS_raw) {
|
||||
std::ifstream image_file_stream(input_folder + pair.first);
|
||||
if (!image_file_stream) {
|
||||
LOG(ERROR) << "Cannot open " << input_folder << pair.first
|
||||
<< ". Skipping.";
|
||||
} else {
|
||||
data_->mutable_string_data(0)->assign(
|
||||
std::istreambuf_iterator<char>(image_file_stream),
|
||||
std::istreambuf_iterator<char>());
|
||||
}
|
||||
} else {
|
||||
// Load image
|
||||
cv::Mat img = cv::imread(
|
||||
input_folder + pair.first,
|
||||
caffe2::FLAGS_color ? CV_LOAD_IMAGE_COLOR
|
||||
: CV_LOAD_IMAGE_GRAYSCALE);
|
||||
|
||||
// Resize image
|
||||
cv::Mat resized_img;
|
||||
int scaled_width, scaled_height;
|
||||
if (caffe2::FLAGS_warp) {
|
||||
scaled_width = caffe2::FLAGS_scale;
|
||||
scaled_height = caffe2::FLAGS_scale;
|
||||
} else if (img.rows > img.cols) {
|
||||
scaled_width = caffe2::FLAGS_scale;
|
||||
scaled_height =
|
||||
static_cast<float>(img.rows) * caffe2::FLAGS_scale / img.cols;
|
||||
} else {
|
||||
scaled_height = caffe2::FLAGS_scale;
|
||||
scaled_width =
|
||||
static_cast<float>(img.cols) * caffe2::FLAGS_scale / img.rows;
|
||||
}
|
||||
cv::resize(
|
||||
img,
|
||||
resized_img,
|
||||
cv::Size(scaled_width, scaled_height),
|
||||
0,
|
||||
0,
|
||||
cv::INTER_LINEAR);
|
||||
data_->set_dims(0, scaled_height);
|
||||
data_->set_dims(1, scaled_width);
|
||||
|
||||
// Assert we don't have to deal with alignment
|
||||
DCHECK(resized_img.isContinuous());
|
||||
auto nbytes = resized_img.total() * resized_img.elemSize();
|
||||
data_->set_byte_data(resized_img.ptr(), nbytes);
|
||||
}
|
||||
|
||||
protos_.SerializeToString(&value);
|
||||
|
||||
// Add serialized proto to out queue or wait if it is not empty
|
||||
lock.lock();
|
||||
while (!out_.empty()) {
|
||||
cv_.wait(lock);
|
||||
}
|
||||
out_.push(value);
|
||||
cv_.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
TensorProtos protos_;
|
||||
TensorProto* data_;
|
||||
TensorProto* label_;
|
||||
std::queue<std::pair<std::string, int>> in_;
|
||||
std::queue<std::string> out_;
|
||||
|
||||
std::mutex mutex_;
|
||||
std::condition_variable cv_;
|
||||
std::thread thread_;
|
||||
};
|
||||
|
||||
void ConvertImageDataset(
|
||||
const string& input_folder,
|
||||
const string& list_filename,
|
||||
const string& output_db_name,
|
||||
const bool /*shuffle*/) {
|
||||
std::ifstream list_file(list_filename);
|
||||
std::vector<std::pair<std::string, int> > lines;
|
||||
std::string filename;
|
||||
int file_label;
|
||||
while (list_file >> filename >> file_label) {
|
||||
lines.push_back(std::make_pair(filename, file_label));
|
||||
}
|
||||
|
||||
if (caffe2::FLAGS_shuffle) {
|
||||
LOG(INFO) << "Shuffling data";
|
||||
std::shuffle(lines.begin(), lines.end(), std::default_random_engine(1701));
|
||||
}
|
||||
|
||||
auto num_threads = caffe2::FLAGS_num_threads;
|
||||
if (num_threads < 1) {
|
||||
num_threads = std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
LOG(INFO) << "Processing " << lines.size() << " images...";
|
||||
LOG(INFO) << "Opening DB " << output_db_name;
|
||||
|
||||
auto db = db::CreateDB(caffe2::FLAGS_db, output_db_name, db::NEW);
|
||||
auto transaction = db->NewTransaction();
|
||||
|
||||
LOG(INFO) << "Using " << num_threads << " processing threads...";
|
||||
std::vector<Converter> converters(num_threads);
|
||||
|
||||
// Queue entries across converters
|
||||
for (auto i = 0; i < lines.size(); i++) {
|
||||
converters[i % converters.size()].queue(lines[i]);
|
||||
}
|
||||
|
||||
// Start all converters
|
||||
for (auto& converter : converters) {
|
||||
converter.start();
|
||||
}
|
||||
|
||||
constexpr auto key_max_length = 256;
|
||||
char key_cstr[key_max_length];
|
||||
string value;
|
||||
int count = 0;
|
||||
for (auto i = 0; i < lines.size(); i++) {
|
||||
// Get serialized proto for this entry
|
||||
auto value = converters[i % converters.size()].get();
|
||||
|
||||
// Synthesize key for this entry
|
||||
auto key_len = snprintf(
|
||||
key_cstr, sizeof(key_cstr), "%08d_%s", i, lines[i].first.c_str());
|
||||
DCHECK_LE(key_len, sizeof(key_cstr));
|
||||
|
||||
// Put in db
|
||||
transaction->Put(string(key_cstr), value);
|
||||
|
||||
if (++count % 1000 == 0) {
|
||||
// Commit the current writes.
|
||||
transaction->Commit();
|
||||
LOG(INFO) << "Processed " << count << " files.";
|
||||
}
|
||||
}
|
||||
|
||||
// Commit final transaction
|
||||
transaction->Commit();
|
||||
LOG(INFO) << "Processed " << count << " files.";
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::ConvertImageDataset(
|
||||
caffe2::FLAGS_input_folder, caffe2::FLAGS_list_file,
|
||||
caffe2::FLAGS_output_db_name, caffe2::FLAGS_shuffle);
|
||||
return 0;
|
||||
}
|
139
binaries/make_mnist_db.cc
Normal file
139
binaries/make_mnist_db.cc
Normal file
@ -0,0 +1,139 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// This script converts the MNIST dataset to leveldb.
|
||||
// The MNIST dataset could be downloaded at
|
||||
// http://yann.lecun.com/exdb/mnist/
|
||||
|
||||
#include <fstream> // NOLINT(readability/streams)
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_string(image_file, "", "The input image file name.");
|
||||
CAFFE2_DEFINE_string(label_file, "", "The label file name.");
|
||||
CAFFE2_DEFINE_string(output_file, "", "The output db name.");
|
||||
CAFFE2_DEFINE_string(db, "leveldb", "The db type.");
|
||||
CAFFE2_DEFINE_int(data_limit, -1,
|
||||
"If set, only output this number of data points.");
|
||||
CAFFE2_DEFINE_bool(channel_first, false,
|
||||
"If set, write the data as channel-first (CHW order) as the old "
|
||||
"Caffe does.");
|
||||
|
||||
namespace caffe2 {
|
||||
uint32_t swap_endian(uint32_t val) {
|
||||
val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
|
||||
return (val << 16) | (val >> 16);
|
||||
}
|
||||
|
||||
void convert_dataset(const char* image_filename, const char* label_filename,
|
||||
const char* db_path, const int data_limit) {
|
||||
// Open files
|
||||
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
|
||||
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
|
||||
CAFFE_ENFORCE(image_file, "Unable to open file ", image_filename);
|
||||
CAFFE_ENFORCE(label_file, "Unable to open file ", label_filename);
|
||||
// Read the magic and the meta data
|
||||
uint32_t magic;
|
||||
uint32_t num_items;
|
||||
uint32_t num_labels;
|
||||
uint32_t rows;
|
||||
uint32_t cols;
|
||||
|
||||
image_file.read(reinterpret_cast<char*>(&magic), 4);
|
||||
magic = swap_endian(magic);
|
||||
if (magic == 529205256) {
|
||||
LOG(FATAL) <<
|
||||
"It seems that you forgot to unzip the mnist dataset. You should "
|
||||
"first unzip them using e.g. gunzip on Linux.";
|
||||
}
|
||||
CAFFE_ENFORCE_EQ(magic, 2051, "Incorrect image file magic.");
|
||||
label_file.read(reinterpret_cast<char*>(&magic), 4);
|
||||
magic = swap_endian(magic);
|
||||
CAFFE_ENFORCE_EQ(magic, 2049, "Incorrect label file magic.");
|
||||
image_file.read(reinterpret_cast<char*>(&num_items), 4);
|
||||
num_items = swap_endian(num_items);
|
||||
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
|
||||
num_labels = swap_endian(num_labels);
|
||||
CAFFE_ENFORCE_EQ(num_items, num_labels);
|
||||
image_file.read(reinterpret_cast<char*>(&rows), 4);
|
||||
rows = swap_endian(rows);
|
||||
image_file.read(reinterpret_cast<char*>(&cols), 4);
|
||||
cols = swap_endian(cols);
|
||||
|
||||
// leveldb
|
||||
std::unique_ptr<db::DB> mnist_db(db::CreateDB(caffe2::FLAGS_db, db_path, db::NEW));
|
||||
std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
|
||||
// Storing to db
|
||||
char label_value;
|
||||
std::vector<char> pixels(rows * cols);
|
||||
int count = 0;
|
||||
const int kMaxKeyLength = 10;
|
||||
char key_cstr[kMaxKeyLength];
|
||||
string value;
|
||||
|
||||
TensorProtos protos;
|
||||
TensorProto* data = protos.add_protos();
|
||||
TensorProto* label = protos.add_protos();
|
||||
data->set_data_type(TensorProto::BYTE);
|
||||
if (caffe2::FLAGS_channel_first) {
|
||||
data->add_dims(1);
|
||||
data->add_dims(rows);
|
||||
data->add_dims(cols);
|
||||
} else {
|
||||
data->add_dims(rows);
|
||||
data->add_dims(cols);
|
||||
data->add_dims(1);
|
||||
}
|
||||
label->set_data_type(TensorProto::INT32);
|
||||
label->add_int32_data(0);
|
||||
|
||||
LOG(INFO) << "A total of " << num_items << " items.";
|
||||
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
|
||||
for (int item_id = 0; item_id < num_items; ++item_id) {
|
||||
image_file.read(pixels.data(), rows * cols);
|
||||
label_file.read(&label_value, 1);
|
||||
for (int i = 0; i < rows * cols; ++i) {
|
||||
data->set_byte_data(pixels.data(), rows * cols);
|
||||
}
|
||||
label->set_int32_data(0, static_cast<int>(label_value));
|
||||
snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
|
||||
protos.SerializeToString(&value);
|
||||
string keystr(key_cstr);
|
||||
|
||||
// Put in db
|
||||
transaction->Put(keystr, value);
|
||||
if (++count % 1000 == 0) {
|
||||
transaction->Commit();
|
||||
}
|
||||
if (data_limit > 0 && count == data_limit) {
|
||||
LOG(INFO) << "Reached data limit of " << data_limit << ", stop.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace caffe2
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::convert_dataset(caffe2::FLAGS_image_file.c_str(), caffe2::FLAGS_label_file.c_str(),
|
||||
caffe2::FLAGS_output_file.c_str(), caffe2::FLAGS_data_limit);
|
||||
return 0;
|
||||
}
|
57
binaries/predictor_verifier.cc
Normal file
57
binaries/predictor_verifier.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/core/flags.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/predictor.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
|
||||
CAFFE2_DEFINE_string(init_net, "", "The given path to the init protobuffer.");
|
||||
CAFFE2_DEFINE_string(
|
||||
predict_net,
|
||||
"",
|
||||
"The given path to the predict protobuffer.");
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
void run() {
|
||||
if (FLAGS_init_net.empty()) {
|
||||
LOG(FATAL) << "No init net specified. Use --init_net=/path/to/net.";
|
||||
}
|
||||
if (FLAGS_predict_net.empty()) {
|
||||
LOG(FATAL) << "No predict net specified. Use --predict_net=/path/to/net.";
|
||||
}
|
||||
caffe2::NetDef init_net, predict_net;
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_init_net, &init_net));
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_predict_net, &predict_net));
|
||||
// Can be large due to constant fills
|
||||
VLOG(1) << "Init net: " << ProtoDebugString(init_net);
|
||||
LOG(INFO) << "Predict net: " << ProtoDebugString(predict_net);
|
||||
auto predictor = caffe2::make_unique<Predictor>(init_net, predict_net);
|
||||
LOG(INFO) << "Checking that a null forward-pass works";
|
||||
Predictor::TensorVector inputVec, outputVec;
|
||||
predictor->run(inputVec, &outputVec);
|
||||
CAFFE_ENFORCE_GT(outputVec.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::run();
|
||||
// This is to allow us to use memory leak checks.
|
||||
google::protobuf::ShutdownProtobufLibrary();
|
||||
return 0;
|
||||
}
|
42
binaries/print_core_object_sizes.cc
Normal file
42
binaries/print_core_object_sizes.cc
Normal file
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
|
||||
#define PRINT_SIZE(cls) \
|
||||
std::cout << "Size of " #cls ": " << sizeof(cls) << " bytes." \
|
||||
<< std::endl;
|
||||
|
||||
int main(int /* unused */, char** /* unused */) {
|
||||
PRINT_SIZE(caffe2::Blob);
|
||||
PRINT_SIZE(caffe2::Tensor<caffe2::CPUContext>);
|
||||
PRINT_SIZE(caffe2::Tensor<caffe2::CUDAContext>);
|
||||
PRINT_SIZE(caffe2::CPUContext);
|
||||
PRINT_SIZE(caffe2::CUDAContext);
|
||||
PRINT_SIZE(caffe2::OperatorBase);
|
||||
PRINT_SIZE(caffe2::OperatorDef);
|
||||
PRINT_SIZE(caffe2::Operator<caffe2::CPUContext>);
|
||||
PRINT_SIZE(caffe2::Operator<caffe2::CUDAContext>);
|
||||
PRINT_SIZE(caffe2::TypeMeta);
|
||||
PRINT_SIZE(caffe2::Workspace);
|
||||
return 0;
|
||||
}
|
73
binaries/print_registered_core_operators.cc
Normal file
73
binaries/print_registered_core_operators.cc
Normal file
@ -0,0 +1,73 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/core/operator_schema.h"
|
||||
|
||||
CAFFE2_DEFINE_string(schema, "",
|
||||
"Print doc and schema of a particular operator");
|
||||
|
||||
static bool HasSchema(const std::string& str) {
|
||||
return caffe2::OpSchemaRegistry::Schema(str);
|
||||
}
|
||||
|
||||
static bool HasDoc(const std::string& str) {
|
||||
const auto* schema = caffe2::OpSchemaRegistry::Schema(str);
|
||||
return (schema != nullptr) && (schema->doc() != nullptr);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
|
||||
if (!caffe2::FLAGS_schema.empty()) {
|
||||
const auto* schema = caffe2::OpSchemaRegistry::Schema(
|
||||
caffe2::FLAGS_schema);
|
||||
if (!schema) {
|
||||
std::cerr << "Operator " << caffe2::FLAGS_schema
|
||||
<< " doesn't have a schema" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
std::cout << "Operator " << caffe2::FLAGS_schema << ": " << std::endl
|
||||
<< *schema;
|
||||
return 0;
|
||||
}
|
||||
|
||||
for (const auto& pair : *caffe2::gDeviceTypeRegistry()) {
|
||||
std::cout << "Device type " << pair.first
|
||||
#ifndef CAFFE2_USE_LITE_PROTO
|
||||
<< " (" << caffe2::DeviceType_Name(
|
||||
static_cast<caffe2::DeviceType>(pair.first))
|
||||
<< ")"
|
||||
#endif
|
||||
<< std::endl;
|
||||
for (const auto& key : pair.second->Keys()) {
|
||||
std::cout << "\t(schema: " << HasSchema(key) << ", doc: " << HasDoc(key)
|
||||
<< ")\t" << key << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Operators that have gradients registered:" << std::endl;
|
||||
for (const auto& key : caffe2::GradientRegistry()->Keys()) {
|
||||
std::cout << "\t(schema: " << HasSchema(key) << ", doc: "
|
||||
<< HasDoc(key) << ")\t"
|
||||
<< key << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
40
binaries/run_plan.cc
Normal file
40
binaries/run_plan.cc
Normal file
@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_string(plan, "", "The given path to the plan protobuffer.");
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
if (caffe2::FLAGS_plan.size() == 0) {
|
||||
LOG(ERROR) << "No plan specified. Use --plan=/path/to/plan.";
|
||||
return 0;
|
||||
}
|
||||
LOG(INFO) << "Loading plan: " << caffe2::FLAGS_plan;
|
||||
caffe2::PlanDef plan_def;
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_plan, &plan_def));
|
||||
std::unique_ptr<caffe2::Workspace> workspace(new caffe2::Workspace());
|
||||
workspace->RunPlan(plan_def);
|
||||
|
||||
// This is to allow us to use memory leak checks.
|
||||
google::protobuf::ShutdownProtobufLibrary();
|
||||
return 0;
|
||||
}
|
48
binaries/run_plan_mpi.cc
Normal file
48
binaries/run_plan_mpi.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_string(plan, "", "The given path to the plan protobuffer.");
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::SetUsageMessage("Runs a caffe2 plan that has MPI operators in it.");
|
||||
int mpi_ret;
|
||||
MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &mpi_ret);
|
||||
if (mpi_ret != MPI_THREAD_MULTIPLE &&
|
||||
mpi_ret != MPI_THREAD_SERIALIZED) {
|
||||
std::cerr << "Caffe2 MPI requires the underlying MPI to support the "
|
||||
"MPI_THREAD_SERIALIZED or MPI_THREAD_MULTIPLE mode.\n";
|
||||
return 1;
|
||||
}
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
LOG(INFO) << "Loading plan: " << caffe2::FLAGS_plan;
|
||||
caffe2::PlanDef plan_def;
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_plan, &plan_def));
|
||||
std::unique_ptr<caffe2::Workspace> workspace(new caffe2::Workspace());
|
||||
workspace->RunPlan(plan_def);
|
||||
|
||||
// This is to allow us to use memory leak checks.
|
||||
google::protobuf::ShutdownProtobufLibrary();
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
169
binaries/speed_benchmark.cc
Normal file
169
binaries/speed_benchmark.cc
Normal file
@ -0,0 +1,169 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
|
||||
CAFFE2_DEFINE_string(net, "", "The given net to benchmark.");
|
||||
CAFFE2_DEFINE_string(
|
||||
init_net,
|
||||
"",
|
||||
"The given net to initialize any parameters.");
|
||||
CAFFE2_DEFINE_string(
|
||||
input,
|
||||
"",
|
||||
"Input that is needed for running the network. If "
|
||||
"multiple input needed, use comma separated string.");
|
||||
CAFFE2_DEFINE_string(
|
||||
input_file,
|
||||
"",
|
||||
"Input file that contain the serialized protobuf for "
|
||||
"the input blobs. If multiple input needed, use comma "
|
||||
"separated string. Must have the same number of items "
|
||||
"as input does.");
|
||||
CAFFE2_DEFINE_string(
|
||||
input_dims,
|
||||
"",
|
||||
"Alternate to input_files, if all inputs are simple "
|
||||
"float TensorCPUs, specify the dimension using comma "
|
||||
"separated numbers. If multiple input needed, use "
|
||||
"semicolon to separate the dimension of different "
|
||||
"tensors.");
|
||||
CAFFE2_DEFINE_string(
|
||||
output,
|
||||
"",
|
||||
"Output that should be dumped after the execution "
|
||||
"finishes. If multiple outputs are needed, use comma "
|
||||
"separated string. If you want to dump everything, pass "
|
||||
"'*' as the output value.");
|
||||
CAFFE2_DEFINE_string(
|
||||
output_folder,
|
||||
"",
|
||||
"The folder that the output should be written to. This "
|
||||
"folder must already exist in the file system.");
|
||||
CAFFE2_DEFINE_int(warmup, 0, "The number of iterations to warm up.");
|
||||
CAFFE2_DEFINE_int(iter, 10, "The number of iterations to run.");
|
||||
CAFFE2_DEFINE_bool(
|
||||
run_individual,
|
||||
false,
|
||||
"Whether to benchmark individual operators.");
|
||||
|
||||
CAFFE2_DEFINE_bool(force_engine, false, "Force engine field for all operators");
|
||||
CAFFE2_DEFINE_string(engine, "", "Forced engine field value");
|
||||
CAFFE2_DEFINE_bool(force_algo, false, "Force algo arg for all operators");
|
||||
CAFFE2_DEFINE_string(algo, "", "Forced algo arg value");
|
||||
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
using std::vector;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
unique_ptr<caffe2::Workspace> workspace(new caffe2::Workspace());
|
||||
|
||||
// Run initialization network.
|
||||
caffe2::NetDef net_def;
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_init_net, &net_def));
|
||||
CAFFE_ENFORCE(workspace->RunNetOnce(net_def));
|
||||
|
||||
// Load input.
|
||||
if (caffe2::FLAGS_input.size()) {
|
||||
vector<string> input_names = caffe2::split(',', caffe2::FLAGS_input);
|
||||
if (caffe2::FLAGS_input_file.size()) {
|
||||
vector<string> input_files = caffe2::split(',', caffe2::FLAGS_input_file);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_files.size(),
|
||||
"Input name and file should have the same number.");
|
||||
for (int i = 0; i < input_names.size(); ++i) {
|
||||
caffe2::BlobProto blob_proto;
|
||||
CAFFE_ENFORCE(caffe2::ReadProtoFromFile(input_files[i], &blob_proto));
|
||||
workspace->CreateBlob(input_names[i])->Deserialize(blob_proto);
|
||||
}
|
||||
} else if (caffe2::FLAGS_input_dims.size()) {
|
||||
vector<string> input_dims_list =
|
||||
caffe2::split(';', caffe2::FLAGS_input_dims);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_dims_list.size(),
|
||||
"Input name and dims should have the same number of items.");
|
||||
for (int i = 0; i < input_names.size(); ++i) {
|
||||
vector<string> input_dims_str = caffe2::split(',', input_dims_list[i]);
|
||||
vector<int> input_dims;
|
||||
for (const string& s : input_dims_str) {
|
||||
input_dims.push_back(caffe2::stoi(s));
|
||||
}
|
||||
caffe2::TensorCPU* tensor =
|
||||
workspace->GetBlob(input_names[i])->GetMutable<caffe2::TensorCPU>();
|
||||
tensor->Resize(input_dims);
|
||||
tensor->mutable_data<float>();
|
||||
}
|
||||
} else {
|
||||
CAFFE_THROW(
|
||||
"You requested input tensors, but neither input_file nor "
|
||||
"input_dims is set.");
|
||||
}
|
||||
}
|
||||
|
||||
// Run main network.
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def));
|
||||
// force changing engine and algo
|
||||
if (caffe2::FLAGS_force_engine) {
|
||||
LOG(INFO) << "force engine be: " << caffe2::FLAGS_engine;
|
||||
for (const auto& op : net_def.op()) {
|
||||
const_cast<caffe2::OperatorDef*>(&op)->set_engine(caffe2::FLAGS_engine);
|
||||
}
|
||||
}
|
||||
if (caffe2::FLAGS_force_algo) {
|
||||
LOG(INFO) << "force algo be: " << caffe2::FLAGS_algo;
|
||||
for (const auto& op : net_def.op()) {
|
||||
caffe2::GetMutableArgument(
|
||||
"algo", true, const_cast<caffe2::OperatorDef*>(&op))
|
||||
->set_s(caffe2::FLAGS_algo);
|
||||
}
|
||||
}
|
||||
caffe2::NetBase* net = workspace->CreateNet(net_def);
|
||||
CHECK_NOTNULL(net);
|
||||
net->TEST_Benchmark(
|
||||
caffe2::FLAGS_warmup, caffe2::FLAGS_iter, caffe2::FLAGS_run_individual);
|
||||
|
||||
string output_prefix = caffe2::FLAGS_output_folder.size()
|
||||
? caffe2::FLAGS_output_folder + "/"
|
||||
: "";
|
||||
if (caffe2::FLAGS_output.size()) {
|
||||
vector<string> output_names = caffe2::split(',', caffe2::FLAGS_output);
|
||||
if (caffe2::FLAGS_output == "*") {
|
||||
output_names = workspace->Blobs();
|
||||
}
|
||||
for (const string& name : output_names) {
|
||||
CAFFE_ENFORCE(
|
||||
workspace->HasBlob(name),
|
||||
"You requested a non-existing blob: ",
|
||||
name);
|
||||
string serialized = workspace->GetBlob(name)->Serialize(name);
|
||||
string output_filename = output_prefix + name;
|
||||
caffe2::WriteStringToFile(serialized, output_filename.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
77
binaries/split_db.cc
Normal file
77
binaries/split_db.cc
Normal file
@ -0,0 +1,77 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
CAFFE2_DEFINE_string(input_db, "", "The input db.");
|
||||
CAFFE2_DEFINE_int(splits, 0, "The number of splits.");
|
||||
CAFFE2_DEFINE_string(db_type, "", "The db type.");
|
||||
CAFFE2_DEFINE_int(batch_size, 1000, "The write batch size.");
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
static int Split(int argc, char** argv) {
|
||||
GlobalInit(&argc, &argv);
|
||||
|
||||
CAFFE_ENFORCE(FLAGS_input_db.size(), "Must specify --input_db=/path/to/db.");
|
||||
CAFFE_ENFORCE(FLAGS_splits > 0, "Must specify a nonnegative split number.");
|
||||
CAFFE_ENFORCE(FLAGS_db_type.size(), "Must specify --db_type=[a db type].");
|
||||
|
||||
unique_ptr<db::DB> in_db(
|
||||
db::CreateDB(FLAGS_db_type, FLAGS_input_db, db::READ));
|
||||
CAFFE_ENFORCE(in_db != nullptr, "Cannot open input db: ", FLAGS_input_db);
|
||||
unique_ptr<db::Cursor> cursor(in_db->NewCursor());
|
||||
// This usually won't happen, but FWIW.
|
||||
CAFFE_ENFORCE(
|
||||
cursor != nullptr, "Cannot obtain cursor for input db: ", FLAGS_input_db);
|
||||
|
||||
vector<unique_ptr<db::DB>> out_dbs;
|
||||
vector<unique_ptr<db::Transaction>> transactions;
|
||||
for (int i = 0; i < FLAGS_splits; ++i) {
|
||||
out_dbs.push_back(unique_ptr<db::DB>(db::CreateDB(
|
||||
FLAGS_db_type, FLAGS_input_db + "_split_" + to_string(i), db::NEW)));
|
||||
CAFFE_ENFORCE(out_dbs.back().get(), "Cannot create output db #", i);
|
||||
transactions.push_back(
|
||||
unique_ptr<db::Transaction>(out_dbs[i]->NewTransaction()));
|
||||
CAFFE_ENFORCE(
|
||||
transactions.back().get(), "Cannot get transaction for output db #", i);
|
||||
}
|
||||
|
||||
int count = 0;
|
||||
for (; cursor->Valid(); cursor->Next()) {
|
||||
transactions[count % FLAGS_splits]->Put(cursor->key(), cursor->value());
|
||||
if (++count % FLAGS_batch_size == 0) {
|
||||
for (int i = 0; i < FLAGS_splits; ++i) {
|
||||
transactions[i]->Commit();
|
||||
}
|
||||
LOG(INFO) << "Split " << count << " items so far.";
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "A total of " << count << " items processed.";
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
return caffe2::Split(argc, argv);
|
||||
}
|
89
binaries/tutorial_blob.cc
Normal file
89
binaries/tutorial_blob.cc
Normal file
@ -0,0 +1,89 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/core/blob.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/tensor.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
|
||||
// We will be lazy and just use the whole namespace.
|
||||
using namespace caffe2;
|
||||
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::ShowLogInfoToStderr();
|
||||
|
||||
LOG(INFO) <<
|
||||
"This script corresponds to the Blob part of the Caffe2 C++ "
|
||||
"tutorial.";
|
||||
|
||||
LOG(INFO) << "Let's create a blob myblob.";
|
||||
|
||||
Blob myblob;
|
||||
|
||||
LOG(INFO) << "Let's set it to int and set the value to 10.";
|
||||
|
||||
int* myint = myblob.GetMutable<int>();
|
||||
*myint = 10;
|
||||
|
||||
LOG(INFO)
|
||||
<< "Is the blob type int? "
|
||||
<< myblob.IsType<int>();
|
||||
|
||||
LOG(INFO)
|
||||
<< "Is the blob type float? "
|
||||
<< myblob.IsType<float>();
|
||||
|
||||
const int& myint_const = myblob.Get<int>();
|
||||
LOG(INFO)
|
||||
<< "The value of the int number stored in the blob is: "
|
||||
<< myint_const;
|
||||
|
||||
LOG(INFO)
|
||||
<< "Let's try to get a float pointer. This will trigger an exception.";
|
||||
|
||||
try {
|
||||
const float& myfloat = myblob.Get<float>();
|
||||
LOG(FATAL) << "This line should never happen.";
|
||||
} catch (std::exception& e) {
|
||||
LOG(INFO)
|
||||
<< "As expected, we got an exception. Its content says: "
|
||||
<< e.what();
|
||||
}
|
||||
|
||||
LOG(INFO) <<
|
||||
"However, we can change the content type (and destroy the old "
|
||||
"content) by calling GetMutable. Let's change it to double.";
|
||||
|
||||
double* mydouble = myblob.GetMutable<double>();
|
||||
*mydouble = 3.14;
|
||||
|
||||
LOG(INFO) << "The new content is: " << myblob.Get<double>();
|
||||
|
||||
LOG(INFO) <<
|
||||
"If we have a pre-created object, we can use Reset() to transfer the "
|
||||
"object to a blob.";
|
||||
|
||||
std::string* pvec = new std::string();
|
||||
myblob.Reset(pvec); // no need to release pvec, myblob takes ownership.
|
||||
|
||||
LOG(INFO) << "Is the blob now of type string? "
|
||||
<< myblob.IsType<std::string>();
|
||||
|
||||
LOG(INFO) << "This concludes the blob tutorial.";
|
||||
return 0;
|
||||
}
|
66
binaries/zmq_feeder.cc
Normal file
66
binaries/zmq_feeder.cc
Normal file
@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// This binary provides an easy way to open a zeromq server and feeds data to
|
||||
// clients connect to it. It uses the Caffe2 db as the backend, thus allowing
|
||||
// one to convert any db-compliant storage to a zeromq service.
|
||||
|
||||
#include "caffe2/core/db.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/utils/zmq_helper.h"
|
||||
|
||||
CAFFE2_DEFINE_string(server, "tcp://*:5555", "The server address.");
|
||||
CAFFE2_DEFINE_string(input_db, "", "The input db.");
|
||||
CAFFE2_DEFINE_string(input_db_type, "", "The input db type.");
|
||||
|
||||
using caffe2::db::DB;
|
||||
using caffe2::db::Cursor;
|
||||
using caffe2::string;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
|
||||
LOG(INFO) << "Opening DB...";
|
||||
auto in_db = caffe2::db::CreateDB(
|
||||
caffe2::FLAGS_input_db_type, caffe2::FLAGS_input_db, caffe2::db::READ);
|
||||
CAFFE_ENFORCE(
|
||||
in_db,
|
||||
"Cannot load input db " + caffe2::FLAGS_input_db + " of expected type " +
|
||||
caffe2::FLAGS_input_db_type);
|
||||
auto cursor = in_db->NewCursor();
|
||||
LOG(INFO) << "DB opened.";
|
||||
|
||||
LOG(INFO) << "Starting ZeroMQ server...";
|
||||
|
||||
// Socket to talk to clients
|
||||
caffe2::ZmqSocket sender(ZMQ_PUSH);
|
||||
sender.Bind(caffe2::FLAGS_server);
|
||||
LOG(INFO) << "Server created at " << caffe2::FLAGS_server;
|
||||
|
||||
while (1) {
|
||||
VLOG(1) << "Sending " << cursor->key();
|
||||
sender.SendTillSuccess(cursor->key(), ZMQ_SNDMORE);
|
||||
sender.SendTillSuccess(cursor->value(), 0);
|
||||
cursor->Next();
|
||||
if (!cursor->Valid()) {
|
||||
cursor->SeekToFirst();
|
||||
}
|
||||
}
|
||||
// We do not do an elegant quit since this binary is going to be terminated by
|
||||
// control+C.
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user