mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update from facebook (#7451)
* [bootcamp] Improve "Shape" operator to support axes specification To improve .shape operator of Caffe2 to support x.shape(tensor, axes), which takes an optional int array "axes" as input. For example, x.shape(tensor, [1, 0]) will return the dimension for axis 1 and 0 following the specified order. For current version, "axes" input allows duplications and can have arbitrary length. * Back out "Add barrier net that runs before training nets" Original commit changeset: b373fdc9c30f. Need additional changes to some callers to support barrier failures. * Change warning to verbose log to reduce log spam The `LOG(WARNING)` was a bit spammy for regular use so lets just make it a `VLOG`. * Extract the shared code from different caffe2_benchmark binaries The OSS benchmark and Internal benchmark will share most functions in the benchmark. * Support MFR in sequence training As titled. * Make knowledge distillation work with using logged prediction feature as teacher label. 1) Add loading raw dense feature as teacher label. 2) Optional calibration function for teacher label 3) Add teacher label into generic unit test 4) Deprecated TTSN workflow version using feature_options to config teacher label * [C2/CUDA]: unjoined cross entropy sigmoid as desc * Add async_scheduling executor into deferrable_net_exec_test Add async_scheduling into tests and fix some exception cases * Fix Event disabled error When disabling event in RNN ops make sure we don't call Finish on disabled event from op's RunAsync * cuda ensure cpu output op can handle both TensorCPU and TensorCUDA as desc. * [C2 Core] Infer input device option in C2 hypothesis_test checkers Improve how we default input blob device options. Previously it defaults as where op lives but it is not necessarily the case. For example: CopyCPUToGPU * [C2 Op]SplitByLengthsOp CPU/GPU implementation [C2 Op]SplitByLengthsOp CPU/GPU implementation * fix undefined symbol error not sure why we're getting undefined symbol even with link_whole = True Need to figure out why but need this workaround for now * Add tools in DAIPlayground platform to help debugging models Add additional tools to allow Plauground override individual method defined in AnyExp. This will allow user to create module that specificly change certain default method behavior. An example included in this diff is deactivating test model and checkpointing. When debugging any model problems, switching off components helps me quickly narrow down the location of the bug. The technique is extensively used in task T27038712 (Steady memory increase in EDPM, eventually resulting in gloo/cuda.cu:34: out of memory) * add shape and type inference for int8 conversion operator * Fix flaky test for group_norm Fix flaky test for group_norm * Fix group_norm_op_test flaky Fix group_norm_op_test flaky * Implementation of composite learning rate policy In many state-of-the-arts deep learning works, people use a simple trick to schedule the learning rate: use a fixed learning rate until error plateaus and then switch to a different fixed learning rate, and so on. In this diff, we implemented a simple version of the composite learning rate. The user gives a set of learning rates policies and corresponding iteration nums, and the optimizer will change the learning rate policy based on the number of iterations so far. For example, the user give two learning rate policies, one is FixedLearningRate and PolyLearningRate, with an iteration number of 1k. Then the first 1k iteration, we use FixedLearningRate. For the following iterations, we use PolyLearningRate. * Split two use cases of CachedReader into two classes, DBFileReader and CachedReader # Use Cases: 1). input: DB file -> output: DatasetReader. Use DBFileReader. 2). input: Reader -> build cache DB file -> output: DatasetReader. Use CachedReader. # Changes to CachedReader: 1). Move db_path to the constructor. Because in mock reader. cache will always be built ahead. # Changes to tests: 1). Make a separate TestCase class for CachedReader and DBFileReader. 2). Make it possible to add more test functions by adding setUp, tearDown and _make_temp_path. 3). Make delete db_path more general. `db_path` could be a file for `log_file_db`, but could also be a directory for `leveldb`. * Back out "On Mobile phones, call GlobalInit with no arguments in predictor in case we need to perform initialization" Original commit changeset: 4489c6133f11 * Fix LARS bug Fixed a bug in the LARS implementation which caused all subsequent blobs not using LARS to have the LARS learning rate multiplier applied to them. * [tum] support sparse init & add uniformFill option as title * Propagate exception for async nets Capture the exception when an exception is thrown in async nets and re-throw it after wait(). This allows exceptions to be propagated up to the caller. This diff was a part of D7752068. We split the diff so that C2 core files changes are in a separate diff. * Automatic update of fbcode/onnx to 69894f207dfcd72d1e70497d387201cec327efbc Previous import was 403ccfbd0161c38f0834413d790bad0874afbf9a Included changes: - **[69894f2](https://github.com/onnx/onnx/commit/69894f2)**: Use op schema.all tensor types in random like definitions (#865) <Scott McKay> - **[b9d6b90](https://github.com/onnx/onnx/commit/b9d6b90)**: Clarify random like operators (#846) <Scott McKay> - **[fc6b5fb](https://github.com/onnx/onnx/commit/fc6b5fb)**: Refactor shape inference implementation (#855) <anderspapitto> - **[b7d8dc8](https://github.com/onnx/onnx/commit/b7d8dc8)**: fix cmake warning message (#863) <Eric S. Yu> - **[f585c5d](https://github.com/onnx/onnx/commit/f585c5d)**: add pytorch-operator test for tile (#831) <Wenhao Hu> - **[993fe70](https://github.com/onnx/onnx/commit/993fe70)**: add install step (#832) <Eric S. Yu> - **[68bc26c](https://github.com/onnx/onnx/commit/68bc26c)**: add type inference for traditional ml ops except classifier ops. (#857) <Ke Zhang> - **[9cc0cda](https://github.com/onnx/onnx/commit/9cc0cda)**: fix string representation of scalar types (#858) <G. Ramalingam> - **[1078925](https://github.com/onnx/onnx/commit/1078925)**: fix y in pow test case to scalar (#852) <Wenhao Hu> - **[c66fb6f](https://github.com/onnx/onnx/commit/c66fb6f)**: Add some math function shape inference (#845) <anderspapitto> - **[ff667d1](https://github.com/onnx/onnx/commit/ff667d1)**: Refactor return type and docs for ONNXIFI_BACKEND_DIRECTX_ID (#853) <Marat Dukhan> - **[11c6876](https://github.com/onnx/onnx/commit/11c6876)**: clear initializer names when clear initializer (#849) <Wenhao Hu> - **[73c34ae](https://github.com/onnx/onnx/commit/73c34ae)**: Clarify FeatureVectorizer description. (#843) <Scott McKay> - **[1befb9b](https://github.com/onnx/onnx/commit/1befb9b)**: Remove useless text in docs (#850) <Lu Fang> - **[e84788f](https://github.com/onnx/onnx/commit/e84788f)**: Fix SELU attributes' default values (#839) <Lu Fang> - **[ebac046](https://github.com/onnx/onnx/commit/ebac046)**: Add tile test case (#823) <Wenhao Hu> - **[8b7a925](https://github.com/onnx/onnx/commit/8b7a925)**: a few more shape inference functions (#772) <anderspapitto> - **[9718f42](https://github.com/onnx/onnx/commit/9718f42)**: Make the coefficient non optional for LinearClassifier (#836) <Jaliya Ekanayake> - **[ef083d0](https://github.com/onnx/onnx/commit/ef083d0)**: Add save_tensor and load_tensor functions for Protos (#770) <Lu Fang> - **[45ceb55](https://github.com/onnx/onnx/commit/45ceb55)**: Check if CMAKE_BUILD_TYPE set before project(). (#812) <Sergii Dymchenko> - **[4b3d2b0](https://github.com/onnx/onnx/commit/4b3d2b0)**: [WIP] reenable shape inference tests (#834) <anderspapitto> - **[22d17ee](https://github.com/onnx/onnx/commit/22d17ee)**: RNN tests: LSTM, GRU, SimpleRNN (#739) <Peyman Manikashani> - **[de65b95](https://github.com/onnx/onnx/commit/de65b95)**: dimension denotation (#443) <Tian Jin> - **[eccc76e](https://github.com/onnx/onnx/commit/eccc76e)**: fix field number issue in onnx operator proto and enable its build (#829) <Ke Zhang> - **[d582beb](https://github.com/onnx/onnx/commit/d582beb)**: disable shape inference test to unbreak ci (#830) <Lu Fang> - **[485b787](https://github.com/onnx/onnx/commit/485b787)**: function proto for composite op. (#802) <Ke Zhang> - **[cd58928](https://github.com/onnx/onnx/commit/cd58928)**: specify defaults for attributes of Affine op (#820) <G. Ramalingam> - **[7ee2cf9](https://github.com/onnx/onnx/commit/7ee2cf9)**: merge the dummy backend back into the main one (#743) <anderspapitto> - **[1c03a5a](https://github.com/onnx/onnx/commit/1c03a5a)**: [Proposal] ONNX Interface for Framework Integration (previously ONNX Backend API) header and docs (#551) <Marat Dukhan> - **[3769a98](https://github.com/onnx/onnx/commit/3769a98)**: Rename real model test case from VGG-16 to ZFNet (#821) <Lu Fang> * [C2]ReluN Op relu n op. tf reference: https://www.tensorflow.org/api_docs/python/tf/nn/relu6 * Call destructor when assigning a blob value * Add executor overrides Add executor overrides flag to enable migration to async_scheduling executor * Add barrier net that runs before training nets - attempt #2 Add a synchonize barrier net that is run before training nets. With this net, shards that are faster will wait for other shards before start training. This reduce chances of the faster shards timing out during GLOO AllReduce. Removed explicit data_parallel_model.py.synchronize call in holmes workflow. This change was landed previously but caused errors for some EDPM workflows - See https://fb.facebook.com/groups/1426530000692545/permalink/1906766366002237/ - because EDPM assumes any call to CreateOrCloneCommonWorld and Gloo ops are wrapped in exception handlers but in this case exception thrown in the barrier init net is not handled. To address this issue, we add _CreateOrCloneCommonWorld to the param_init_net instead of a new barrier init net. Since errors for param_init_net run is handled gracefully and re-rendezvous, it should fixes the problem. * Handle empty nets in async_scheduling Make sure we don't get stuck on empty nets * use CUDA_ARCH for conditional compile * [C2 fix] infer function for ensure_cpu_output_op * Update group_norm test to reduce flaky test * Fix lr_multiplier for GPU
This commit is contained in:
committed by
GitHub
parent
947155c69d
commit
b875fb281c
@ -10,6 +10,7 @@ caffe2_binary_target("split_db.cc")
|
||||
|
||||
caffe2_binary_target("db_throughput.cc")
|
||||
|
||||
|
||||
if (USE_CUDA)
|
||||
caffe2_binary_target("inspect_gpus.cc")
|
||||
target_link_libraries(inspect_gpus ${CUDA_LIBRARIES})
|
||||
@ -45,7 +46,10 @@ if (USE_OPENCV)
|
||||
endif()
|
||||
|
||||
if (USE_OBSERVERS)
|
||||
caffe2_binary_target("caffe2_benchmark.cc")
|
||||
add_executable(caffe2_benchmark "caffe2_benchmark.cc" "benchmark_helper.cc")
|
||||
target_link_libraries(caffe2_benchmark ${Caffe2_MAIN_LIBS})
|
||||
target_link_libraries(caffe2_benchmark ${Caffe2_MODULES})
|
||||
install(TARGETS caffe2_benchmark DESTINATION bin)
|
||||
endif()
|
||||
|
||||
# ---[ tutorials
|
||||
|
250
binaries/benchmark_helper.cc
Normal file
250
binaries/benchmark_helper.cc
Normal file
@ -0,0 +1,250 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "binaries/benchmark_helper.h"
|
||||
#include "caffe2/core/blob_serialization.h"
|
||||
#ifdef __CUDA_ARCH__
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#endif
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/net.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
#include "observers/net_observer_reporter_print.h"
|
||||
#include "observers/observer_config.h"
|
||||
#include "observers/perf_observer.h"
|
||||
|
||||
using std::shared_ptr;
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
using std::vector;
|
||||
|
||||
void observerConfig() {
|
||||
caffe2::ClearGlobalNetObservers();
|
||||
caffe2::AddGlobalNetObserverCreator([](caffe2::NetBase* subject) {
|
||||
return caffe2::make_unique<caffe2::PerfNetObserver>(subject);
|
||||
});
|
||||
caffe2::ObserverConfig::setReporter(
|
||||
caffe2::make_unique<caffe2::NetObserverReporterPrint>());
|
||||
}
|
||||
|
||||
bool backendCudaSet(const string& backend) {
|
||||
bool run_on_gpu = false;
|
||||
if (backend == "cuda") {
|
||||
#ifdef __CUDA_ARCH__
|
||||
if (caffe2::HasCudaGPU()) {
|
||||
run_on_gpu = true;
|
||||
} else {
|
||||
CAFFE_THROW("NO GPU support on this host machine");
|
||||
}
|
||||
#else
|
||||
CAFFE_THROW("NO GPU support");
|
||||
#endif
|
||||
}
|
||||
return run_on_gpu;
|
||||
}
|
||||
|
||||
void setDeviceType(caffe2::NetDef* net_def, caffe2::DeviceType& run_dev) {
|
||||
for (int j = 0; j < net_def->op_size(); j++) {
|
||||
caffe2::OperatorDef* op = net_def->mutable_op(j);
|
||||
op->mutable_device_option()->set_device_type(run_dev);
|
||||
}
|
||||
}
|
||||
|
||||
void setOperatorEngine(caffe2::NetDef* net_def, const string& backend) {
|
||||
if (backend != "builtin") {
|
||||
string engine = backend == "nnpack" ? "NNPACK"
|
||||
: backend == "eigen" ? "EIGEN"
|
||||
: backend == "mkl"
|
||||
? "MKLDNN"
|
||||
: backend == "cuda" ? "CUDA"
|
||||
: backend == "default" ? "" : "NONE";
|
||||
CAFFE_ENFORCE(engine != "NONE", "Backend is not supported");
|
||||
for (int i = 0; i < net_def->op_size(); i++) {
|
||||
caffe2::OperatorDef* op_def = net_def->mutable_op(i);
|
||||
op_def->set_engine(engine);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void loadInput(
|
||||
shared_ptr<caffe2::Workspace> workspace,
|
||||
const bool run_on_gpu,
|
||||
const string& input,
|
||||
const string& input_file,
|
||||
const string& input_dims,
|
||||
const string& input_type) {
|
||||
// Load input.
|
||||
if (input.size()) {
|
||||
vector<string> input_names = caffe2::split(',', input);
|
||||
if (input_file.size()) {
|
||||
vector<string> input_files = caffe2::split(',', input_file);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_files.size(),
|
||||
"Input name and file should have the same number.");
|
||||
for (int i = 0; i < input_names.size(); ++i) {
|
||||
caffe2::BlobProto blob_proto;
|
||||
CAFFE_ENFORCE(caffe2::ReadProtoFromFile(input_files[i], &blob_proto));
|
||||
workspace->CreateBlob(input_names[i])->Deserialize(blob_proto);
|
||||
}
|
||||
} else if (input_dims.size() || input_type.size()) {
|
||||
CAFFE_ENFORCE_GE(
|
||||
input_dims.size(),
|
||||
0,
|
||||
"Input dims must be specified when input tensors are used.");
|
||||
CAFFE_ENFORCE_GE(
|
||||
input_type.size(),
|
||||
0,
|
||||
"Input type must be specified when input tensors are used.");
|
||||
|
||||
vector<string> input_dims_list = caffe2::split(';', input_dims);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_dims_list.size(),
|
||||
"Input name and dims should have the same number of items.");
|
||||
vector<string> input_type_list = caffe2::split(';', input_type);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_type_list.size(),
|
||||
"Input name and type should have the same number of items.");
|
||||
for (size_t i = 0; i < input_names.size(); ++i) {
|
||||
vector<string> input_dims_str = caffe2::split(',', input_dims_list[i]);
|
||||
vector<int> input_dims;
|
||||
for (const string& s : input_dims_str) {
|
||||
input_dims.push_back(caffe2::stoi(s));
|
||||
}
|
||||
caffe2::Blob* blob = workspace->GetBlob(input_names[i]);
|
||||
if (blob == nullptr) {
|
||||
blob = workspace->CreateBlob(input_names[i]);
|
||||
}
|
||||
if (run_on_gpu) {
|
||||
LOG(INFO) << "Running on GPU.";
|
||||
#ifdef __CUDA_ARCH__
|
||||
caffe2::TensorCUDA* tensor = blob->GetMutable<caffe2::TensorCUDA>();
|
||||
CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
if (input_type_list[i] == "uint8_t") {
|
||||
tensor->mutable_data<uint8_t>();
|
||||
} else if (input_type_list[i] == "float") {
|
||||
tensor->mutable_data<float>();
|
||||
} else {
|
||||
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
|
||||
}
|
||||
#else
|
||||
CAFFE_THROW("Not support GPU on mobile.");
|
||||
#endif
|
||||
} else {
|
||||
caffe2::TensorCPU* tensor = blob->GetMutable<caffe2::TensorCPU>();
|
||||
CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
if (input_type_list[i] == "uint8_t") {
|
||||
tensor->mutable_data<uint8_t>();
|
||||
} else if (input_type_list[i] == "float") {
|
||||
tensor->mutable_data<float>();
|
||||
} else {
|
||||
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
CAFFE_THROW(
|
||||
"You requested input tensors, but neither input_file nor "
|
||||
"input_dims is set.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void runNetwork(
|
||||
shared_ptr<caffe2::Workspace> workspace,
|
||||
caffe2::NetDef& net_def,
|
||||
const bool run_individual,
|
||||
const int warmup,
|
||||
const int iter) {
|
||||
if (!net_def.has_name()) {
|
||||
net_def.set_name("benchmark");
|
||||
}
|
||||
|
||||
caffe2::NetBase* net = workspace->CreateNet(net_def);
|
||||
CHECK_NOTNULL(net);
|
||||
|
||||
LOG(INFO) << "Starting benchmark.";
|
||||
caffe2::ObserverConfig::initSampleRate(1, 1, 1, run_individual, warmup);
|
||||
LOG(INFO) << "Running warmup runs.";
|
||||
for (int i = 0; i < warmup; ++i) {
|
||||
CAFFE_ENFORCE(net->Run(), "Warmup run ", i, " has failed.");
|
||||
}
|
||||
|
||||
LOG(INFO) << "Main runs.";
|
||||
CAFFE_ENFORCE(
|
||||
iter >= 0,
|
||||
"Number of main runs should be non negative, provided ",
|
||||
iter,
|
||||
".");
|
||||
for (int i = 0; i < iter; ++i) {
|
||||
caffe2::ObserverConfig::initSampleRate(1, 1, 1, 0, warmup);
|
||||
CAFFE_ENFORCE(net->Run(), "Main run ", i, " has failed.");
|
||||
if (run_individual) {
|
||||
caffe2::ObserverConfig::initSampleRate(1, 1, 1, 1, warmup);
|
||||
CAFFE_ENFORCE(net->Run(), "Main run ", i, " with operator has failed.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void writeOutput(
|
||||
shared_ptr<caffe2::Workspace> workspace,
|
||||
const bool run_on_gpu,
|
||||
const string& output,
|
||||
const string& output_folder,
|
||||
const bool text_output) {
|
||||
string output_prefix = output_folder.size() ? output_folder + "/" : "";
|
||||
if (output.size()) {
|
||||
vector<string> output_names = caffe2::split(',', output);
|
||||
if (output == "*") {
|
||||
output_names = workspace->Blobs();
|
||||
}
|
||||
for (const string& name : output_names) {
|
||||
CAFFE_ENFORCE(
|
||||
workspace->HasBlob(name),
|
||||
"You requested a non-existing blob: ",
|
||||
name);
|
||||
if (text_output) {
|
||||
if (run_on_gpu) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
writeTextOutput<caffe2::CUDAContext, caffe2::TensorCUDA>(
|
||||
workspace->GetBlob(name)->GetMutable<caffe2::TensorCUDA>(),
|
||||
output_prefix,
|
||||
name);
|
||||
#else
|
||||
CAFFE_THROW("Not support GPU.");
|
||||
#endif
|
||||
} else {
|
||||
writeTextOutput<caffe2::CPUContext, caffe2::TensorCPU>(
|
||||
workspace->GetBlob(name)->GetMutable<caffe2::TensorCPU>(),
|
||||
output_prefix,
|
||||
name);
|
||||
}
|
||||
} else {
|
||||
string serialized = workspace->GetBlob(name)->Serialize(name);
|
||||
string output_filename = output_prefix + name;
|
||||
caffe2::WriteStringToFile(serialized, output_filename.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
91
binaries/benchmark_helper.h
Normal file
91
binaries/benchmark_helper.h
Normal file
@ -0,0 +1,91 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/blob_serialization.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/net.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
|
||||
using std::shared_ptr;
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
template <typename ContextType, typename TensorType>
|
||||
void writeTextOutput(
|
||||
TensorType* tensor,
|
||||
const string& output_prefix,
|
||||
const string& name) {
|
||||
string output_name = output_prefix + "/" + name + ".txt";
|
||||
caffe2::TensorSerializer<ContextType> ser;
|
||||
caffe2::BlobProto blob_proto;
|
||||
ser.Serialize(
|
||||
*tensor, output_name, blob_proto.mutable_tensor(), 0, tensor->size());
|
||||
blob_proto.set_name(output_name);
|
||||
blob_proto.set_type("Tensor");
|
||||
CAFFE_ENFORCE(blob_proto.has_tensor());
|
||||
caffe2::TensorProto tensor_proto = blob_proto.tensor();
|
||||
vector<float> data;
|
||||
switch (tensor_proto.data_type()) {
|
||||
case caffe2::TensorProto::FLOAT: {
|
||||
std::copy(
|
||||
tensor_proto.float_data().begin(),
|
||||
tensor_proto.float_data().end(),
|
||||
std::back_inserter(data));
|
||||
break;
|
||||
}
|
||||
case caffe2::TensorProto::INT32: {
|
||||
std::copy(
|
||||
tensor_proto.int32_data().begin(),
|
||||
tensor_proto.int32_data().end(),
|
||||
std::back_inserter(data));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
CAFFE_THROW("Unimplemented Blob type.");
|
||||
}
|
||||
std::ofstream output_file(output_name);
|
||||
std::ostream_iterator<float> output_iterator(output_file, "\n");
|
||||
std::copy(data.begin(), data.end(), output_iterator);
|
||||
}
|
||||
|
||||
void observerConfig();
|
||||
bool backendCudaSet(const string&);
|
||||
void setDeviceType(caffe2::NetDef*, caffe2::DeviceType&);
|
||||
void setOperatorEngine(caffe2::NetDef*, const string&);
|
||||
void loadInput(
|
||||
shared_ptr<caffe2::Workspace>,
|
||||
const bool,
|
||||
const string&,
|
||||
const string&,
|
||||
const string&,
|
||||
const string&);
|
||||
void writeOutput(
|
||||
shared_ptr<caffe2::Workspace>,
|
||||
const bool,
|
||||
const string&,
|
||||
const string&,
|
||||
const bool);
|
||||
void runNetwork(
|
||||
shared_ptr<caffe2::Workspace>,
|
||||
caffe2::NetDef&,
|
||||
const bool,
|
||||
const int,
|
||||
const int);
|
@ -2,24 +2,18 @@
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
|
||||
#include "caffe2/core/blob_serialization.h"
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/net.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
#include "binaries/benchmark_helper.h"
|
||||
|
||||
#include "observers/net_observer_reporter_print.h"
|
||||
#include "observers/observer_config.h"
|
||||
#include "observers/perf_observer.h"
|
||||
using std::make_shared;
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
CAFFE2_DEFINE_string(
|
||||
backend,
|
||||
"builtin",
|
||||
"The backend to use when running the model. The allowed "
|
||||
"backend choices are: builtin, default, nnpack, eigen, mkl");
|
||||
"backend choices are: builtin, default, nnpack, eigen, mkl, cuda");
|
||||
|
||||
CAFFE2_DEFINE_string(
|
||||
init_net,
|
||||
"",
|
||||
@ -73,200 +67,52 @@ CAFFE2_DEFINE_bool(
|
||||
"Whether to write out output in text format for regression purpose.");
|
||||
CAFFE2_DEFINE_int(warmup, 0, "The number of iterations to warm up.");
|
||||
|
||||
using std::string;
|
||||
using std::unique_ptr;
|
||||
using std::vector;
|
||||
|
||||
static void writeTextOutput(
|
||||
caffe2::TensorCPU* tensor,
|
||||
const string& output_prefix,
|
||||
const string& name) {
|
||||
string output_name = output_prefix + "/" + name + ".txt";
|
||||
caffe2::TensorSerializer<caffe2::CPUContext> ser;
|
||||
caffe2::BlobProto blob_proto;
|
||||
ser.Serialize(
|
||||
*tensor, output_name, blob_proto.mutable_tensor(), 0, tensor->size());
|
||||
blob_proto.set_name(output_name);
|
||||
blob_proto.set_type("Tensor");
|
||||
CAFFE_ENFORCE(blob_proto.has_tensor());
|
||||
caffe2::TensorProto tensor_proto = blob_proto.tensor();
|
||||
vector<float> data;
|
||||
switch (tensor_proto.data_type()) {
|
||||
case caffe2::TensorProto::FLOAT: {
|
||||
std::copy(
|
||||
tensor_proto.float_data().begin(),
|
||||
tensor_proto.float_data().end(),
|
||||
std::back_inserter(data));
|
||||
break;
|
||||
}
|
||||
case caffe2::TensorProto::INT32: {
|
||||
std::copy(
|
||||
tensor_proto.int32_data().begin(),
|
||||
tensor_proto.int32_data().end(),
|
||||
std::back_inserter(data));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
CAFFE_THROW("Unimplemented Blob type.");
|
||||
}
|
||||
std::ofstream output_file(output_name);
|
||||
std::ostream_iterator<float> output_iterator(output_file, "\n");
|
||||
std::copy(data.begin(), data.end(), output_iterator);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
caffe2::ClearGlobalNetObservers();
|
||||
caffe2::AddGlobalNetObserverCreator([](caffe2::NetBase* subject) {
|
||||
return caffe2::make_unique<caffe2::PerfNetObserver>(subject);
|
||||
});
|
||||
caffe2::ObserverConfig::setReporter(
|
||||
caffe2::make_unique<caffe2::NetObserverReporterPrint>());
|
||||
|
||||
observerConfig();
|
||||
caffe2::ShowLogInfoToStderr();
|
||||
unique_ptr<caffe2::Workspace> workspace(new caffe2::Workspace());
|
||||
|
||||
auto workspace = make_shared<caffe2::Workspace>(new caffe2::Workspace());
|
||||
bool run_on_gpu = backendCudaSet(caffe2::FLAGS_backend);
|
||||
|
||||
// support other device type in the future?
|
||||
caffe2::DeviceType run_dev = run_on_gpu ? caffe2::CUDA : caffe2::CPU;
|
||||
|
||||
// Run initialization network.
|
||||
caffe2::NetDef init_net_def;
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_init_net, &init_net_def));
|
||||
setDeviceType(&init_net_def, run_dev);
|
||||
setOperatorEngine(&init_net_def, caffe2::FLAGS_backend);
|
||||
CAFFE_ENFORCE(workspace->RunNetOnce(init_net_def));
|
||||
|
||||
// Load input.
|
||||
if (caffe2::FLAGS_input.size()) {
|
||||
vector<string> input_names = caffe2::split(',', caffe2::FLAGS_input);
|
||||
if (caffe2::FLAGS_input_file.size()) {
|
||||
vector<string> input_files = caffe2::split(',', caffe2::FLAGS_input_file);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_files.size(),
|
||||
"Input name and file should have the same number.");
|
||||
for (int i = 0; i < input_names.size(); ++i) {
|
||||
caffe2::BlobProto blob_proto;
|
||||
CAFFE_ENFORCE(caffe2::ReadProtoFromFile(input_files[i], &blob_proto));
|
||||
workspace->CreateBlob(input_names[i])->Deserialize(blob_proto);
|
||||
}
|
||||
} else if (
|
||||
caffe2::FLAGS_input_dims.size() || caffe2::FLAGS_input_type.size()) {
|
||||
CAFFE_ENFORCE_GE(
|
||||
caffe2::FLAGS_input_dims.size(),
|
||||
0,
|
||||
"Input dims must be specified when input tensors are used.");
|
||||
CAFFE_ENFORCE_GE(
|
||||
caffe2::FLAGS_input_type.size(),
|
||||
0,
|
||||
"Input type must be specified when input tensors are used.");
|
||||
|
||||
vector<string> input_dims_list =
|
||||
caffe2::split(';', caffe2::FLAGS_input_dims);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_dims_list.size(),
|
||||
"Input name and dims should have the same number of items.");
|
||||
vector<string> input_type_list =
|
||||
caffe2::split(';', caffe2::FLAGS_input_type);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
input_names.size(),
|
||||
input_type_list.size(),
|
||||
"Input name and type should have the same number of items.");
|
||||
for (size_t i = 0; i < input_names.size(); ++i) {
|
||||
vector<string> input_dims_str = caffe2::split(',', input_dims_list[i]);
|
||||
vector<int> input_dims;
|
||||
for (const string& s : input_dims_str) {
|
||||
input_dims.push_back(caffe2::stoi(s));
|
||||
}
|
||||
caffe2::Blob* blob = workspace->GetBlob(input_names[i]);
|
||||
if (blob == nullptr) {
|
||||
blob = workspace->CreateBlob(input_names[i]);
|
||||
}
|
||||
caffe2::TensorCPU* tensor = blob->GetMutable<caffe2::TensorCPU>();
|
||||
CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
if (input_type_list[i] == "uint8_t") {
|
||||
tensor->mutable_data<uint8_t>();
|
||||
} else if (input_type_list[i] == "float") {
|
||||
tensor->mutable_data<float>();
|
||||
} else {
|
||||
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
CAFFE_THROW(
|
||||
"You requested input tensors, but neither input_file nor "
|
||||
"input_dims is set.");
|
||||
}
|
||||
}
|
||||
|
||||
// Run main network.
|
||||
caffe2::NetDef net_def;
|
||||
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def));
|
||||
if (!net_def.has_name()) {
|
||||
net_def.set_name("benchmark");
|
||||
}
|
||||
if (caffe2::FLAGS_backend != "builtin") {
|
||||
std::string engine = caffe2::FLAGS_backend == "nnpack"
|
||||
? "NNPACK"
|
||||
: caffe2::FLAGS_backend == "eigen" ? "EIGEN"
|
||||
: caffe2::FLAGS_backend == "mkl"
|
||||
? "MKLDNN"
|
||||
: caffe2::FLAGS_backend == "dnnlowp"
|
||||
? "DNNLOWP"
|
||||
: caffe2::FLAGS_backend == "default" ? "" : "NONE";
|
||||
CAFFE_ENFORCE(engine != "NONE", "Backend is not supported");
|
||||
for (int i = 0; i < net_def.op_size(); i++) {
|
||||
caffe2::OperatorDef* op_def = net_def.mutable_op(i);
|
||||
op_def->set_engine(engine);
|
||||
}
|
||||
}
|
||||
setDeviceType(&net_def, run_dev);
|
||||
setOperatorEngine(&net_def, caffe2::FLAGS_backend);
|
||||
|
||||
caffe2::NetBase* net = workspace->CreateNet(net_def);
|
||||
CHECK_NOTNULL(net);
|
||||
loadInput(
|
||||
workspace,
|
||||
run_on_gpu,
|
||||
caffe2::FLAGS_input,
|
||||
caffe2::FLAGS_input_file,
|
||||
caffe2::FLAGS_input_dims,
|
||||
caffe2::FLAGS_input_type);
|
||||
|
||||
LOG(INFO) << "Starting benchmark.";
|
||||
caffe2::ObserverConfig::initSampleRate(
|
||||
1, 1, 1, caffe2::FLAGS_run_individual, caffe2::FLAGS_warmup);
|
||||
LOG(INFO) << "Running warmup runs.";
|
||||
for (int i = 0; i < caffe2::FLAGS_warmup; ++i) {
|
||||
CAFFE_ENFORCE(net->Run(), "Warmup run ", i, " has failed.");
|
||||
}
|
||||
runNetwork(
|
||||
workspace,
|
||||
net_def,
|
||||
caffe2::FLAGS_run_individual,
|
||||
caffe2::FLAGS_warmup,
|
||||
caffe2::FLAGS_iter);
|
||||
|
||||
LOG(INFO) << "Main runs.";
|
||||
CAFFE_ENFORCE(
|
||||
caffe2::FLAGS_iter >= 0,
|
||||
"Number of main runs should be non negative, provided ",
|
||||
caffe2::FLAGS_iter,
|
||||
".");
|
||||
for (int i = 0; i < caffe2::FLAGS_iter; ++i) {
|
||||
caffe2::ObserverConfig::initSampleRate(1, 1, 1, 0, caffe2::FLAGS_warmup);
|
||||
CAFFE_ENFORCE(net->Run(), "Main run ", i, " has failed.");
|
||||
if (caffe2::FLAGS_run_individual) {
|
||||
caffe2::ObserverConfig::initSampleRate(1, 1, 1, 1, caffe2::FLAGS_warmup);
|
||||
CAFFE_ENFORCE(net->Run(), "Main run ", i, " with operator has failed.");
|
||||
}
|
||||
}
|
||||
|
||||
string output_prefix = caffe2::FLAGS_output_folder.size()
|
||||
? caffe2::FLAGS_output_folder + "/"
|
||||
: "";
|
||||
if (caffe2::FLAGS_output.size()) {
|
||||
vector<string> output_names = caffe2::split(',', caffe2::FLAGS_output);
|
||||
if (caffe2::FLAGS_output == "*") {
|
||||
output_names = workspace->Blobs();
|
||||
}
|
||||
for (const string& name : output_names) {
|
||||
CAFFE_ENFORCE(
|
||||
workspace->HasBlob(name),
|
||||
"You requested a non-existing blob: ",
|
||||
name);
|
||||
if (caffe2::FLAGS_text_output) {
|
||||
auto blob = workspace->GetBlob(name)->GetMutable<caffe2::TensorCPU>();
|
||||
writeTextOutput(blob, output_prefix, name);
|
||||
} else {
|
||||
string serialized = workspace->GetBlob(name)->Serialize(name);
|
||||
string output_filename = output_prefix + name;
|
||||
caffe2::WriteStringToFile(serialized, output_filename.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
writeOutput(
|
||||
workspace,
|
||||
run_on_gpu,
|
||||
caffe2::FLAGS_output,
|
||||
caffe2::FLAGS_output_folder,
|
||||
caffe2::FLAGS_text_output);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
20
caffe2/contrib/ideep/CMakeLists.txt
Normal file
20
caffe2/contrib/ideep/CMakeLists.txt
Normal file
@ -0,0 +1,20 @@
|
||||
if(USE_MKL AND USE_IDEEP AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
|
||||
message(STATUS "Including IDEEP operators")
|
||||
|
||||
# ---[ CPU files.
|
||||
file(GLOB_RECURSE avx2_srcs *.cc)
|
||||
# exclude test files and gpu files
|
||||
file(GLOB_RECURSE tmp *_test.cc)
|
||||
exclude(avx2_srcs "${avx2_srcs}" ${tmp})
|
||||
|
||||
add_library(Caffe2_ideep_operators OBJECT ${avx2_srcs})
|
||||
add_dependencies(Caffe2_ideep_operators Caffe_PROTO Caffe2_PROTO)
|
||||
set_target_properties(Caffe2_ideep_operators PROPERTIES COMPILE_FLAGS "-mavx2")
|
||||
|
||||
# ---[ Send the lists to the parent scope.
|
||||
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
|
||||
$<TARGET_OBJECTS:Caffe2_ideep_operators>)
|
||||
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
|
||||
else()
|
||||
message(STATUS "Excluding ideep operators as we are not using ideep")
|
||||
endif()
|
@ -68,68 +68,14 @@ def createTrainerClass(opts):
|
||||
return ModuleRegister.constructTrainerClass(AnyExpTrainer, opts)
|
||||
|
||||
|
||||
def runShardedTrainLoop(opts, myTrainFun):
|
||||
start_epoch = 0
|
||||
pretrained_model = opts['model_param']['pretrained_model']
|
||||
if pretrained_model != '' and os.path.exists(pretrained_model):
|
||||
# Only want to get start_epoch.
|
||||
start_epoch, prev_checkpointed_lr, best_metric = \
|
||||
checkpoint.initialize_params_from_file(
|
||||
model=None,
|
||||
weights_file=pretrained_model,
|
||||
num_xpus=1,
|
||||
opts=opts,
|
||||
broadcast_computed_param=True,
|
||||
reset_epoch=opts['model_param']['reset_epoch'],
|
||||
)
|
||||
log.info('start epoch: {}'.format(start_epoch))
|
||||
pretrained_model = None if pretrained_model == '' else pretrained_model
|
||||
ret = None
|
||||
|
||||
pretrained_model = ""
|
||||
shard_results = []
|
||||
|
||||
for epoch in range(start_epoch,
|
||||
opts['epoch_iter']['num_epochs'],
|
||||
opts['epoch_iter']['num_epochs_per_flow_schedule']):
|
||||
# must support checkpoint or the multiple schedule will always
|
||||
# start from initial state
|
||||
checkpoint_model = None if epoch == start_epoch else ret['model']
|
||||
pretrained_model = None if epoch > start_epoch else pretrained_model
|
||||
shard_results = []
|
||||
# with LexicalContext('epoch{}_gang'.format(epoch),gang_schedule=False):
|
||||
for shard_id in range(opts['distributed']['num_shards']):
|
||||
opts['temp_var']['shard_id'] = shard_id
|
||||
opts['temp_var']['pretrained_model'] = pretrained_model
|
||||
opts['temp_var']['checkpoint_model'] = checkpoint_model
|
||||
opts['temp_var']['epoch'] = epoch
|
||||
opts['temp_var']['start_epoch'] = start_epoch
|
||||
shard_ret = myTrainFun(opts)
|
||||
shard_results.append(shard_ret)
|
||||
|
||||
ret = None
|
||||
# always only take shard_0 return
|
||||
for shard_ret in shard_results:
|
||||
if shard_ret is not None:
|
||||
ret = shard_ret
|
||||
opts['temp_var']['metrics_output'] = ret['metrics']
|
||||
break
|
||||
log.info('ret is: {}'.format(str(ret)))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def trainFun():
|
||||
def simpleTrainFun(opts):
|
||||
trainerClass = createTrainerClass(opts)
|
||||
trainer = trainerClass(opts)
|
||||
return trainer.buildModelAndTrain(opts)
|
||||
return simpleTrainFun
|
||||
def overrideAdditionalMethods(myTrainerClass, opts):
|
||||
return ModuleRegister.overrideAdditionalMethods(myTrainerClass, opts)
|
||||
|
||||
|
||||
def initialize_params_from_file(*args, **kwargs):
|
||||
return checkpoint.initialize_params_from_file(*args, **kwargs)
|
||||
|
||||
|
||||
class AnyExpTrainer(object):
|
||||
|
||||
def __init__(self, opts):
|
||||
@ -321,6 +267,18 @@ class AnyExpTrainer(object):
|
||||
def gen_rendezvous_ctx(self, model, dataset, is_train):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_training_net(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_testing_net(self):
|
||||
if self.test_model is None:
|
||||
return
|
||||
timeout = 2000.0
|
||||
with timeout_guard.CompleteInTimeOrDie(timeout):
|
||||
workspace.RunNet(self.test_model.net.Proto().name)
|
||||
|
||||
# @abstractmethod
|
||||
def planning_output(self):
|
||||
self.init_metrics()
|
||||
@ -335,7 +293,7 @@ class AnyExpTrainer(object):
|
||||
|
||||
def prep_a_data_parallel_model(self, model, dataset, is_train):
|
||||
if model is None:
|
||||
pass
|
||||
return
|
||||
|
||||
log.info('in prep_a_data_parallel_model')
|
||||
|
||||
@ -376,7 +334,17 @@ class AnyExpTrainer(object):
|
||||
workspace.RunNetOnce(model.param_init_net)
|
||||
log.info('in prep_a_data_parallel_model RunNetOnce done ')
|
||||
|
||||
# for op in model.net.Proto().op:
|
||||
# log.info('op type engine {} {}'.format(op.type, op.engine))
|
||||
|
||||
log.info('model.net.Proto() {}'.format(model.net.Proto()))
|
||||
|
||||
workspace.CreateNet(model.net)
|
||||
|
||||
# for op in model.net.Proto().op:
|
||||
# log.info('after CreateNet op type engine {} {}'.
|
||||
# format(op.type, op.engine))
|
||||
|
||||
log.info('in prep_a_data_parallel_model CreateNet done ')
|
||||
|
||||
def loadCheckpoint(self):
|
||||
@ -416,6 +384,7 @@ class AnyExpTrainer(object):
|
||||
log.info('in buildModelAndTrain, trainer_input: {}'.format(str(opts)))
|
||||
log.info("check type self: {}".format(type(self)))
|
||||
log.info("check self dir: {}".format(dir(self)))
|
||||
log.info("check self source: {}".format(self.__dict__))
|
||||
log.info("check self get_input_dataset methods: {}".
|
||||
format(inspect.getsource(self.get_input_dataset)))
|
||||
log.info("check self gen_input_builder_fun method: {}".
|
||||
@ -430,6 +399,8 @@ class AnyExpTrainer(object):
|
||||
format(inspect.getsource(self.gen_optimizer_fun)))
|
||||
log.info("check self assembleAllOutputs method: {}".
|
||||
format(inspect.getsource(self.assembleAllOutputs)))
|
||||
log.info("check self prep_data_parallel_models method: {}".
|
||||
format(inspect.getsource(self.prep_data_parallel_models)))
|
||||
|
||||
self.get_model_input_fun()
|
||||
|
||||
@ -452,7 +423,10 @@ class AnyExpTrainer(object):
|
||||
self.iter_start_time = time.time()
|
||||
|
||||
self.fun_per_iter_b4RunNet(epoch, epoch_iter)
|
||||
self.run_training_net()
|
||||
|
||||
if self.train_model is not None:
|
||||
self.run_training_net()
|
||||
|
||||
self.fun_per_iter_aftRunNetB4Test(epoch, epoch_iter)
|
||||
|
||||
self.iter_end_time = time.time()
|
||||
@ -479,9 +453,7 @@ class AnyExpTrainer(object):
|
||||
|
||||
self.test_loop_start_time = time.time()
|
||||
for _test_iter in range(0, opts['epoch_iter']['num_test_iter']):
|
||||
timeout = 2000.0
|
||||
with timeout_guard.CompleteInTimeOrDie(timeout):
|
||||
workspace.RunNet(self.test_model.net.Proto().name)
|
||||
self.run_testing_net()
|
||||
for key in self.metrics:
|
||||
metric = self.metrics[key]
|
||||
if metric['is_train']:
|
||||
|
@ -5,13 +5,77 @@ from __future__ import unicode_literals
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import caffe2.contrib.playground.AnyExp as AnyExp
|
||||
import caffe2.contrib.playground.checkpoint as checkpoint
|
||||
|
||||
import logging
|
||||
logging.basicConfig()
|
||||
log = logging.getLogger("AnyExpOnTerm")
|
||||
log.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def runShardedTrainLoop(opts, myTrainFun):
|
||||
start_epoch = 0
|
||||
pretrained_model = opts['model_param']['pretrained_model']
|
||||
if pretrained_model != '' and os.path.exists(pretrained_model):
|
||||
# Only want to get start_epoch.
|
||||
start_epoch, prev_checkpointed_lr, best_metric = \
|
||||
checkpoint.initialize_params_from_file(
|
||||
model=None,
|
||||
weights_file=pretrained_model,
|
||||
num_xpus=1,
|
||||
opts=opts,
|
||||
broadcast_computed_param=True,
|
||||
reset_epoch=opts['model_param']['reset_epoch'],
|
||||
)
|
||||
log.info('start epoch: {}'.format(start_epoch))
|
||||
pretrained_model = None if pretrained_model == '' else pretrained_model
|
||||
ret = None
|
||||
|
||||
pretrained_model = ""
|
||||
shard_results = []
|
||||
|
||||
for epoch in range(start_epoch,
|
||||
opts['epoch_iter']['num_epochs'],
|
||||
opts['epoch_iter']['num_epochs_per_flow_schedule']):
|
||||
# must support checkpoint or the multiple schedule will always
|
||||
# start from initial state
|
||||
checkpoint_model = None if epoch == start_epoch else ret['model']
|
||||
pretrained_model = None if epoch > start_epoch else pretrained_model
|
||||
shard_results = []
|
||||
# with LexicalContext('epoch{}_gang'.format(epoch),gang_schedule=False):
|
||||
for shard_id in range(opts['distributed']['num_shards']):
|
||||
opts['temp_var']['shard_id'] = shard_id
|
||||
opts['temp_var']['pretrained_model'] = pretrained_model
|
||||
opts['temp_var']['checkpoint_model'] = checkpoint_model
|
||||
opts['temp_var']['epoch'] = epoch
|
||||
opts['temp_var']['start_epoch'] = start_epoch
|
||||
shard_ret = myTrainFun(opts)
|
||||
shard_results.append(shard_ret)
|
||||
|
||||
ret = None
|
||||
# always only take shard_0 return
|
||||
for shard_ret in shard_results:
|
||||
if shard_ret is not None:
|
||||
ret = shard_ret
|
||||
opts['temp_var']['metrics_output'] = ret['metrics']
|
||||
break
|
||||
log.info('ret is: {}'.format(str(ret)))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def trainFun():
|
||||
def simpleTrainFun(opts):
|
||||
trainerClass = AnyExp.createTrainerClass(opts)
|
||||
trainerClass = AnyExp.overrideAdditionalMethods(trainerClass, opts)
|
||||
trainer = trainerClass(opts)
|
||||
return trainer.buildModelAndTrain(opts)
|
||||
return simpleTrainFun
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='Any Experiment training.')
|
||||
@ -29,6 +93,6 @@ if __name__ == '__main__':
|
||||
|
||||
# defined this way so that AnyExp.trainFun(opts) can be replaced with
|
||||
# some other custermized training function.
|
||||
ret = AnyExp.runShardedTrainLoop(opts, AnyExp.trainFun())
|
||||
ret = runShardedTrainLoop(opts, trainFun())
|
||||
|
||||
log.info('ret is: {}'.format(str(ret)))
|
||||
|
@ -78,8 +78,22 @@ def constructTrainerClass(myTrainerClass, opts):
|
||||
return myTrainerClass
|
||||
|
||||
|
||||
def overrideAdditionalMethods(myTrainerClass, opts):
|
||||
log.info("B4 additional override myTrainerClass source {}".
|
||||
format(inspect.getsource(myTrainerClass)))
|
||||
# override any additional modules
|
||||
myAdditionalOverride = getModule(opts['model']['additional_override_py'])
|
||||
if myAdditionalOverride is not None:
|
||||
for funcName, funcValue in inspect.getmembers(myAdditionalOverride,
|
||||
inspect.isfunction):
|
||||
setattr(myTrainerClass, funcName, funcValue)
|
||||
log.info("Aft additional override myTrainerClass's source {}".
|
||||
format(inspect.getsource(myTrainerClass)))
|
||||
return myTrainerClass
|
||||
|
||||
|
||||
def getModule(moduleName):
|
||||
log.info("MODULE_MAPS content {}".format(str(MODULE_MAPS)))
|
||||
log.info("get module {} from MODULE_MAPS content {}".format(moduleName, str(MODULE_MAPS)))
|
||||
myModule = None
|
||||
for ModuleMap in MODULE_MAPS:
|
||||
log.info("iterate through MODULE_MAPS content {}".
|
||||
|
@ -41,13 +41,14 @@ $ python caffe2/contrib/playground/AnyExpOnTerm.py --parameters-json '{
|
||||
"forward_pass_py":"caffe2_resnet50_default_forward",
|
||||
"parameter_update_py":"explicit_resnet_param_update",
|
||||
"optimizer_py":"",
|
||||
"rendezvous_py":"rendezvous_filestore"},
|
||||
"rendezvous_py":"rendezvous_filestore",
|
||||
"additional_override_py":""},
|
||||
|
||||
"model_param":{
|
||||
"pretrained_model":"", "reset_epoch":true, "memonger" : true, "cuda_nccl": true,
|
||||
"combine_spatial_bn":true, "max_concurrent_distributed_ops" : 16,
|
||||
"base_learning_rate":0.05, "bn_epsilon":0.00001, "bn_momentum":0.9, "custom_bn_init": true,
|
||||
"bn_init_gamma":1e-323, "weight_decay":1e-4, "weight_decay_bn":1e-323},
|
||||
"bn_init_gamma":1e-323, "weight_decay":1e-4, "weight_decay_bn":1e-323, "engine":"CUDNN"},
|
||||
|
||||
"epoch_iter":{
|
||||
"num_train_sample_per_epoch":10240,
|
||||
@ -149,3 +150,5 @@ $ python caffe2/contrib/playground/AnyExpOnTerm.py --parameters-json '{
|
||||
6. In the demo, the opts item “gen_output_py” uses output_generator.py , which provides a minimum way to generating final experimental result, stored in the form of a dict. It will allow user to do whatever visualization with these data after the training is finished.
|
||||
|
||||
7. Customize your experimental result. A meter interface is provided to implement your own metrics calculators. Example compute_loss.py and compute_topk_accuracy.py. For training metrics, results are calculated right away in each iteration. For testing metrics, results are accumulated for the whole loop and finally calculated after test iteration finishes. Once your have your meter class defined, you can start defining what metrics to report in your opts['output']['metrics'] list. The name you give to your metrics can later be used when you define your plots. The Playground will always record throughput metrics secs_per_train and samples_per_sec.
|
||||
|
||||
8. an additional_override_py option is provided for the modules to allow user override any existing methods defined in the main framework AnyExp.py. This make it easy to shut down part of the model to focus on remaining modules for experimenting or debugging. An example is given as override_no_test_model_no_checkpoint.py, which turns off checkpointing and does neither prepare nor run test model.
|
||||
|
@ -123,6 +123,9 @@ def broadcast_parameters(opts, model, num_xpus, broadcast_computed_param=False):
|
||||
|
||||
def save_model_params(is_checkpoint, model, checkpoint_path, epoch, opts, best_metric):
|
||||
# best_metric=float('-inf')
|
||||
if checkpoint_path is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
save_model_params_blob(
|
||||
model, checkpoint_path, epoch, opts, best_metric
|
||||
|
@ -11,6 +11,13 @@ import caffe2.contrib.playground.resnetdemo.\
|
||||
import caffe2.contrib.playground.resnetdemo.\
|
||||
IN1k_resnet as IN1k_resnet # noqa
|
||||
|
||||
import caffe2.contrib.playground.resnetdemo.\
|
||||
IN1k_resnet_no_test_model as IN1k_resnet_no_test_model # noqa
|
||||
|
||||
# Additional override
|
||||
import caffe2.contrib.playground.resnetdemo.\
|
||||
override_no_test_model_no_checkpoint as override_no_test_model_no_checkpoint # noqa
|
||||
|
||||
# FORWARD_PASS
|
||||
import caffe2.contrib.playground.resnetdemo.\
|
||||
caffe2_resnet50_default_forward as caffe2_resnet50_default_forward # noqa
|
||||
|
@ -22,7 +22,7 @@ def init_model(self):
|
||||
test_model = cnn.CNNModelHelper(
|
||||
order="NCHW",
|
||||
name="resnet_test",
|
||||
use_cudnn=False,
|
||||
use_cudnn=True,
|
||||
cudnn_exhaustive_search=False,
|
||||
init_params=False,
|
||||
)
|
||||
|
@ -0,0 +1,62 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import numpy as np
|
||||
|
||||
from caffe2.python import workspace, cnn, core
|
||||
from caffe2.python import timeout_guard
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
|
||||
def init_model(self):
|
||||
# if cudnn needs to be turned off, several other places
|
||||
# need to be modified:
|
||||
# 1. operators need to be constructed with engine option, like below:
|
||||
# conv_blob = model.Conv(...engine=engine)
|
||||
# 2. when launch model, opts['model_param']['engine'] = "" instead of "CUDNN"
|
||||
# 2. caffe2_disable_implicit_engine_preference in operator.cc set to true
|
||||
train_model = cnn.CNNModelHelper(
|
||||
order="NCHW",
|
||||
name="resnet",
|
||||
use_cudnn=False,
|
||||
cudnn_exhaustive_search=False,
|
||||
)
|
||||
self.train_model = train_model
|
||||
|
||||
# test_model = cnn.CNNModelHelper(
|
||||
# order="NCHW",
|
||||
# name="resnet_test",
|
||||
# use_cudnn=False,
|
||||
# cudnn_exhaustive_search=False,
|
||||
# init_params=False,
|
||||
# )
|
||||
self.test_model = None
|
||||
|
||||
self.log.info("Model creation completed")
|
||||
|
||||
|
||||
def fun_per_epoch_b4RunNet(self, epoch):
|
||||
pass
|
||||
|
||||
|
||||
def fun_per_iter_b4RunNet(self, epoch, epoch_iter):
|
||||
learning_rate = 0.05
|
||||
for idx in range(self.opts['distributed']['first_xpu_id'],
|
||||
self.opts['distributed']['first_xpu_id'] +
|
||||
self.opts['distributed']['num_xpus']):
|
||||
caffe2_pb2_device = caffe2_pb2.CUDA if \
|
||||
self.opts['distributed']['device'] == 'gpu' else \
|
||||
caffe2_pb2.CPU
|
||||
with core.DeviceScope(core.DeviceOption(caffe2_pb2_device, idx)):
|
||||
workspace.FeedBlob(
|
||||
'{}_{}/lr'.format(self.opts['distributed']['device'], idx),
|
||||
np.array(learning_rate, dtype=np.float32)
|
||||
)
|
||||
|
||||
|
||||
def run_training_net(self):
|
||||
timeout = 2000.0
|
||||
with timeout_guard.CompleteInTimeOrDie(timeout):
|
||||
workspace.RunNet(self.train_model.net.Proto().name)
|
@ -41,6 +41,7 @@ def gen_forward_pass_builder_fun(self, model, dataset, is_train):
|
||||
def resnet_imagenet_create_model(model, data, labels, split, opts, dataset):
|
||||
model_helper = ResNetModelHelper(model, split, opts)
|
||||
opts_depth = opts['model_param']['num_layer']
|
||||
engine = opts['model_param']['engine']
|
||||
log.info(' | ResNet-{} Imagenet'.format(opts_depth))
|
||||
assert opts_depth in BLOCK_CONFIG.keys(), \
|
||||
'Block config is not defined for specified model depth. Please check.'
|
||||
@ -55,7 +56,7 @@ def resnet_imagenet_create_model(model, data, labels, split, opts, dataset):
|
||||
num_classes = 1000
|
||||
conv_blob = model.Conv(
|
||||
data, 'conv1', 3, 64, 7, stride=2, pad=3, weight_init=('MSRAFill', {}),
|
||||
bias_init=('ConstantFill', {'value': 0.}), no_bias=0
|
||||
bias_init=('ConstantFill', {'value': 0.}), no_bias=0, engine=engine
|
||||
)
|
||||
test_mode = False
|
||||
if split in ['test', 'val']:
|
||||
@ -137,6 +138,8 @@ class ResNetModelHelper():
|
||||
self.model = model
|
||||
self.split = split
|
||||
self.opts = opts
|
||||
self.engine = opts['model_param']['engine']
|
||||
|
||||
|
||||
# shortcut type B
|
||||
def add_shortcut(self, blob_in, dim_in, dim_out, stride, prefix):
|
||||
@ -146,7 +149,7 @@ class ResNetModelHelper():
|
||||
blob_in, prefix, dim_in, dim_out, kernel=1,
|
||||
stride=stride,
|
||||
weight_init=("MSRAFill", {}),
|
||||
bias_init=('ConstantFill', {'value': 0.}), no_bias=1
|
||||
bias_init=('ConstantFill', {'value': 0.}), no_bias=1, engine=self.engine
|
||||
)
|
||||
test_mode = False
|
||||
if self.split in ['test', 'val']:
|
||||
@ -168,7 +171,7 @@ class ResNetModelHelper():
|
||||
blob_in, prefix, dim_in, dim_out, kernel, stride=stride,
|
||||
pad=pad, group=group,
|
||||
weight_init=("MSRAFill", {}),
|
||||
bias_init=('ConstantFill', {'value': 0.}), no_bias=1
|
||||
bias_init=('ConstantFill', {'value': 0.}), no_bias=1, engine=self.engine
|
||||
)
|
||||
test_mode = False
|
||||
if self.split in ['test', 'val']:
|
||||
@ -201,7 +204,7 @@ class ResNetModelHelper():
|
||||
conv_blob = self.model.GroupConv_Deprecated(
|
||||
blob_out, prefix + "_branch2b", dim_inner, dim_inner, kernel=3,
|
||||
stride=stride, pad=1, group=group, weight_init=("MSRAFill", {}),
|
||||
bias_init=('ConstantFill', {'value': 0.}), no_bias=1
|
||||
bias_init=('ConstantFill', {'value': 0.}), no_bias=1, engine=self.engine
|
||||
)
|
||||
test_mode = False
|
||||
if self.split in ['test', 'val']:
|
||||
|
@ -0,0 +1,16 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
def checkpoint(self, epoch):
|
||||
self.model_path = None
|
||||
pass
|
||||
|
||||
def prep_data_parallel_models(self):
|
||||
# only do train_model no test needed here
|
||||
self.prep_a_data_parallel_model(self.train_model,
|
||||
self.train_dataset, True)
|
||||
|
||||
def run_testing_net(self):
|
||||
pass
|
@ -42,6 +42,9 @@ class Blob {
|
||||
}
|
||||
|
||||
Blob& operator=(Blob&& other) noexcept {
|
||||
if (pointer_ && destroy_) {
|
||||
destroy_(pointer_);
|
||||
}
|
||||
meta_ = std::move(other.meta_);
|
||||
pointer_ = std::move(other.pointer_);
|
||||
destroy_ = std::move(other.destroy_);
|
||||
|
@ -49,16 +49,4 @@ bool GlobalInit(int* pargc, char*** pargv) {
|
||||
// TODO: if we fail GlobalInit(), should we continue?
|
||||
return success;
|
||||
}
|
||||
|
||||
#if CAFFE2_MOBILE
|
||||
bool GlobalInit() {
|
||||
// On mobile devices, run global init here, since we cannot pass the
|
||||
// command line options to caffe2, no arguments are passed.
|
||||
int mobile_argc = 1;
|
||||
char caffe2_name[] = "caffe2";
|
||||
char* mobile_name = &caffe2_name[0];
|
||||
char** mobile_argv = &mobile_name;
|
||||
return ::caffe2::GlobalInit(&mobile_argc, &mobile_argv);
|
||||
}
|
||||
#endif
|
||||
} // namespace caffe2
|
||||
|
@ -96,8 +96,5 @@ class InitRegisterer {
|
||||
*/
|
||||
bool GlobalInit(int* pargc, char*** argv);
|
||||
|
||||
#if CAFFE2_MOBILE
|
||||
bool GlobalInit();
|
||||
#endif
|
||||
} // namespace caffe2
|
||||
#endif // CAFFE2_CORE_INIT_H_
|
||||
|
@ -9,6 +9,12 @@
|
||||
#include "caffe2/core/timer.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
|
||||
CAFFE2_DEFINE_string(
|
||||
caffe2_override_executor,
|
||||
"",
|
||||
"Comma-separated list of executor overrides");
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
@ -89,10 +95,28 @@ bool NetBase::RunAsync() {
|
||||
}
|
||||
|
||||
namespace {
|
||||
const std::string kSimpleNet = "simple";
|
||||
|
||||
std::vector<NetObserverCreator>* GetNetObserverCreators() {
|
||||
static std::vector<NetObserverCreator> creators;
|
||||
return &creators;
|
||||
}
|
||||
|
||||
void checkExecutorOverride(std::string& net_type) {
|
||||
auto executors = caffe2::split(',', FLAGS_caffe2_override_executor);
|
||||
CAFFE_ENFORCE(
|
||||
executors.size() % 2 == 0, "Invalid override executors flag value");
|
||||
std::unordered_map<std::string, std::string> overrides;
|
||||
for (auto idx = 0; idx < executors.size() - 1; idx += 2) {
|
||||
overrides[executors[idx]] = executors[idx + 1];
|
||||
}
|
||||
if (overrides.count(net_type)) {
|
||||
LOG(INFO) << "Overrode net type '" << net_type << "' with '"
|
||||
<< overrides[net_type] << "'";
|
||||
net_type = overrides[net_type];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AddGlobalNetObserverCreator(NetObserverCreator creator) {
|
||||
@ -113,14 +137,19 @@ unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
|
||||
unique_ptr<NetBase> CreateNet(
|
||||
const std::shared_ptr<const NetDef>& net_def,
|
||||
Workspace* ws) {
|
||||
// In default, we will return a simple network that just runs all operators
|
||||
// sequentially.
|
||||
unique_ptr<NetBase> net;
|
||||
if (!net_def->has_type()) {
|
||||
net = std::unique_ptr<NetBase>(new SimpleNet(net_def, ws));
|
||||
std::string net_type;
|
||||
if (net_def->has_type()) {
|
||||
net_type = net_def->type();
|
||||
} else {
|
||||
net = NetRegistry()->Create(net_def->type(), net_def, ws);
|
||||
// By default, we will return a simple network that just runs all operators
|
||||
// sequentially.
|
||||
net_type = kSimpleNet;
|
||||
}
|
||||
if (!FLAGS_caffe2_override_executor.empty()) {
|
||||
checkExecutorOverride(net_type);
|
||||
}
|
||||
unique_ptr<NetBase> net = NetRegistry()->Create(net_type, net_def, ws);
|
||||
|
||||
VLOG(1) << "Adding a global observer to a net";
|
||||
if (net) {
|
||||
auto* observer_creators = GetNetObserverCreators();
|
||||
|
@ -21,6 +21,8 @@
|
||||
#include "caffe2/utils/simple_queue.h"
|
||||
#include "caffe2/utils/thread_pool.h"
|
||||
|
||||
CAFFE2_DECLARE_string(caffe2_override_executor);
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
class NetBase;
|
||||
@ -56,12 +58,16 @@ class NetBase : public Observable<NetBase> {
|
||||
return false;
|
||||
}
|
||||
Wait();
|
||||
handleRunError();
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual void handleRunError() {
|
||||
for (const Event* event : events_) {
|
||||
if (event->Query() != EventStatus::EVENT_SUCCESS) {
|
||||
CAFFE_THROW(event->ErrorMessage());
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual bool RunAsync();
|
||||
|
@ -74,6 +74,16 @@ AsyncNetBase::AsyncNetBase(
|
||||
}
|
||||
}
|
||||
|
||||
void AsyncNetBase::handleRunError() {
|
||||
#ifdef CAFFE2_USE_EXCEPTION_PTR
|
||||
std::unique_lock<std::mutex> exception_lock(exception_mutex_);
|
||||
if (caught_exception_) {
|
||||
std::rethrow_exception(caught_exception_);
|
||||
}
|
||||
#endif // CAFFE2_USE_EXCEPTION_PTR
|
||||
NetBase::handleRunError();
|
||||
}
|
||||
|
||||
bool AsyncNetBase::RunAsync() {
|
||||
tracing::startIter(tracer_);
|
||||
for (auto& op : GetOperators()) {
|
||||
@ -145,7 +155,8 @@ bool AsyncNetBase::isStreamFree(int task_id, int stream_id) const {
|
||||
|
||||
bool AsyncNetBase::canSchedule(
|
||||
int task_id,
|
||||
const std::vector<EventStatus>* status) {
|
||||
const std::vector<EventStatus>* status,
|
||||
bool* parent_failed) {
|
||||
auto first_child_op_id = chains_[task_id].front();
|
||||
for (auto parent_id : parents(task_id)) {
|
||||
auto last_parent_op_id = chains_[parent_id].back();
|
||||
@ -155,6 +166,11 @@ bool AsyncNetBase::canSchedule(
|
||||
} else {
|
||||
parent_status = operators_[last_parent_op_id]->event().Query();
|
||||
}
|
||||
|
||||
if (parent_status == EventStatus::EVENT_FAILED && parent_failed) {
|
||||
*parent_failed = true;
|
||||
}
|
||||
|
||||
bool can_schedule = Event::CanSchedule(
|
||||
operators_[last_parent_op_id]->event().GetType(),
|
||||
parent_status,
|
||||
@ -210,6 +226,15 @@ void AsyncNetBase::asyncWait(
|
||||
first_op->WaitEvents(events, stream_id);
|
||||
}
|
||||
|
||||
void AsyncNetBase::storeExceptionPtr() {
|
||||
#ifdef CAFFE2_USE_EXCEPTION_PTR
|
||||
std::unique_lock<std::mutex> exception_lock(exception_mutex_);
|
||||
if (!caught_exception_) {
|
||||
caught_exception_ = std::current_exception();
|
||||
}
|
||||
#endif // CAFFE2_USE_EXCEPTION_PTR
|
||||
}
|
||||
|
||||
void AsyncNetBase::run(int task_id, int stream_id) {
|
||||
std::string err_msg;
|
||||
for (auto& op_id : chains_[task_id]) {
|
||||
@ -224,13 +249,29 @@ void AsyncNetBase::run(int task_id, int stream_id) {
|
||||
stream_id);
|
||||
CAFFE_ENFORCE(op->RunAsync(stream_id), "Failed to execute an op");
|
||||
} catch (const std::exception& e) {
|
||||
CAFFE_THROW(
|
||||
std::string(e.what()) + ", op " +
|
||||
(op->has_debug_def() ? op->type() : " unknown"));
|
||||
#ifdef CAFFE2_USE_EXCEPTION_PTR
|
||||
storeExceptionPtr();
|
||||
#endif // CAFFE2_USE_EXCEPTION_PTR
|
||||
auto err_msg = std::string(e.what()) + ", op " +
|
||||
(op->has_debug_def() ? op->type() : " unknown");
|
||||
if (query(task_id) == EventStatus::EVENT_INITIALIZED) {
|
||||
// mark the chain's event as failed,
|
||||
// not throwing because event is in initialized state
|
||||
event(task_id).SetFinished(err_msg.c_str());
|
||||
}
|
||||
LOG(ERROR) << err_msg;
|
||||
throw;
|
||||
} catch (...) {
|
||||
CAFFE_THROW(
|
||||
"Failed to execute task: unknown error, op " +
|
||||
(op->has_debug_def() ? op->type() : " unknown"));
|
||||
#ifdef CAFFE2_USE_EXCEPTION_PTR
|
||||
storeExceptionPtr();
|
||||
#endif // CAFFE2_USE_EXCEPTION_PTR
|
||||
auto err_msg = "Failed to execute task: unknown error, op " +
|
||||
(op->has_debug_def() ? op->type() : " unknown");
|
||||
if (query(task_id) == EventStatus::EVENT_INITIALIZED) {
|
||||
event(task_id).SetFinished(err_msg.c_str());
|
||||
}
|
||||
LOG(ERROR) << err_msg;
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,12 +34,15 @@ class AsyncNetBase : public NetBase {
|
||||
return operators_;
|
||||
}
|
||||
|
||||
void handleRunError() override;
|
||||
|
||||
bool RunAsync() override;
|
||||
|
||||
protected:
|
||||
bool canSchedule(
|
||||
int chain_id,
|
||||
const std::vector<EventStatus>* status = nullptr);
|
||||
const std::vector<EventStatus>* status = nullptr,
|
||||
bool* parent_failed = nullptr);
|
||||
|
||||
int tasksNum() const;
|
||||
Event& event(int task_id) const;
|
||||
@ -78,12 +81,19 @@ class AsyncNetBase : public NetBase {
|
||||
static thread_local std::vector<int> stream_counters_;
|
||||
int num_workers_;
|
||||
|
||||
#ifdef CAFFE2_USE_EXCEPTION_PTR
|
||||
// Mutex that protects caught_exception_
|
||||
std::mutex exception_mutex_;
|
||||
std::exception_ptr caught_exception_;
|
||||
#endif // CAFFE2_USE_EXCEPTION_PTR
|
||||
// Tracing
|
||||
std::shared_ptr<tracing::Tracer> tracer_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(AsyncNetBase);
|
||||
|
||||
private:
|
||||
void storeExceptionPtr();
|
||||
|
||||
std::shared_ptr<TaskThreadPool>
|
||||
pool_getter(PoolsMap& pools, int device_type, int device_id, int pool_size);
|
||||
|
||||
|
@ -24,7 +24,6 @@ void AsyncSchedulingNet::reset() {
|
||||
auto& task_op_node = operator_nodes_[task_ops.front()];
|
||||
task_op_node.runtime_parent_count_ = parents(task_id).size();
|
||||
}
|
||||
exception_messages_.clear();
|
||||
}
|
||||
|
||||
void AsyncSchedulingNet::Wait() {
|
||||
@ -43,8 +42,6 @@ void AsyncSchedulingNet::schedule(int task_id) {
|
||||
try {
|
||||
run(task_id, stream_id);
|
||||
} catch (const std::exception& e) {
|
||||
std::unique_lock<std::mutex> lock(exception_mutex_);
|
||||
exception_messages_.push_back(e.what());
|
||||
success_ = false;
|
||||
}
|
||||
}
|
||||
@ -54,7 +51,8 @@ void AsyncSchedulingNet::schedule(int task_id) {
|
||||
for (auto child_id : children(task_id)) {
|
||||
int parent_count = updateParentCount(child_id);
|
||||
if (parent_count == 0) {
|
||||
if (cleanup_ || FLAGS_caffe2_net_async_always_schedule_child ||
|
||||
if (!success_ || cleanup_ ||
|
||||
FLAGS_caffe2_net_async_always_schedule_child ||
|
||||
canSchedule(child_id)) {
|
||||
schedule(child_id);
|
||||
} else {
|
||||
@ -103,8 +101,16 @@ void AsyncSchedulingNet::schedule(int task_id) {
|
||||
}
|
||||
|
||||
void AsyncSchedulingNet::pollAndSchedule(int task_id) {
|
||||
if (canSchedule(task_id) || cleanup_) {
|
||||
// force schedule the rest of the tasks if cleanup is started
|
||||
bool parent_failed = false;
|
||||
bool can_schedule = canSchedule(task_id, nullptr, &parent_failed);
|
||||
if (parent_failed) {
|
||||
success_ = false;
|
||||
}
|
||||
// schedule the task if:
|
||||
// - parents are ready
|
||||
// - we failed / cleanup started (no ops will run)
|
||||
|
||||
if (can_schedule || cleanup_ || !success_ || parent_failed) {
|
||||
schedule(task_id);
|
||||
} else {
|
||||
const auto& device_option = event(task_id).GetDeviceOption();
|
||||
@ -142,6 +148,10 @@ bool AsyncSchedulingNet::DoRunAsync() {
|
||||
}
|
||||
}
|
||||
|
||||
if (tasksNum() == 0) {
|
||||
finishRun();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -32,8 +32,6 @@ class AsyncSchedulingNet : public AsyncNetBase {
|
||||
std::atomic<bool> cleanup_;
|
||||
|
||||
std::atomic<int> processed_tasks_num_;
|
||||
std::mutex exception_mutex_;
|
||||
std::vector<std::string> exception_messages_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(AsyncSchedulingNet);
|
||||
};
|
||||
|
@ -1,9 +1,12 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "caffe2/core/net.h"
|
||||
#include "caffe2/core/net_async_scheduling.h"
|
||||
#include "caffe2/core/net_dag.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/core/scope_guard.h"
|
||||
|
||||
#include <google/protobuf/text_format.h>
|
||||
|
||||
CAFFE2_DECLARE_bool(caffe2_disable_chaining);
|
||||
|
||||
namespace caffe2 {
|
||||
@ -85,7 +88,7 @@ unique_ptr<NetBase> CreateNetTestHelper(
|
||||
return CreateNet(net_def, ws);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
||||
TEST(NetTest, ConstructionNoDeclaredInputOutput) {
|
||||
Workspace ws;
|
||||
@ -115,8 +118,7 @@ TEST(NetTest, DeclaredInputInsufficient) {
|
||||
Workspace ws;
|
||||
ws.CreateBlob("in");
|
||||
ASSERT_THROW(
|
||||
CreateNetTestHelper(&ws, vector<string>{"unuseful_in"},
|
||||
vector<string>()),
|
||||
CreateNetTestHelper(&ws, vector<string>{"unuseful_in"}, vector<string>()),
|
||||
EnforceNotMet);
|
||||
}
|
||||
|
||||
@ -124,8 +126,8 @@ TEST(NetDeathTest, DeclaredOutputNotMet) {
|
||||
Workspace ws;
|
||||
ws.CreateBlob("in");
|
||||
ASSERT_THROW(
|
||||
CreateNetTestHelper(&ws, vector<string>(),
|
||||
vector<string>{"unproduced_out"}),
|
||||
CreateNetTestHelper(
|
||||
&ws, vector<string>(), vector<string>{"unproduced_out"}),
|
||||
EnforceNotMet);
|
||||
}
|
||||
|
||||
@ -144,7 +146,8 @@ void checkChainingAndRun(
|
||||
Workspace ws;
|
||||
ws.CreateBlob("in");
|
||||
NetDef net_def;
|
||||
CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def));
|
||||
CAFFE_ENFORCE(
|
||||
::google::protobuf::TextFormat::ParseFromString(spec, &net_def));
|
||||
{
|
||||
net_def.set_num_workers(4);
|
||||
auto old = FLAGS_caffe2_disable_chaining;
|
||||
@ -164,7 +167,8 @@ void checkNumChainsAndRun(const char* spec, const int expected_num_chains) {
|
||||
Workspace ws;
|
||||
|
||||
NetDef net_def;
|
||||
CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def));
|
||||
CAFFE_ENFORCE(
|
||||
::google::protobuf::TextFormat::ParseFromString(spec, &net_def));
|
||||
net_def.set_num_workers(4);
|
||||
|
||||
// Create all external inputs
|
||||
@ -563,7 +567,8 @@ TEST(NetTest, FailingOperator) {
|
||||
ws.CreateBlob("in");
|
||||
|
||||
NetDef net_def;
|
||||
CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def));
|
||||
CAFFE_ENFORCE(
|
||||
::google::protobuf::TextFormat::ParseFromString(spec, &net_def));
|
||||
|
||||
{
|
||||
net_def.set_num_workers(4);
|
||||
@ -574,7 +579,7 @@ TEST(NetTest, FailingOperator) {
|
||||
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
|
||||
for (int i = 0; i < 10; i++) {
|
||||
counter.exchange(0);
|
||||
ASSERT_FALSE(net.get()->Run());
|
||||
ASSERT_FALSE(net->Run());
|
||||
ASSERT_EQ(1, counter.load());
|
||||
}
|
||||
}
|
||||
@ -612,7 +617,8 @@ TEST(NetTest, OperatorWithExecutorHelper) {
|
||||
)DOC";
|
||||
|
||||
NetDef net_def;
|
||||
CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def));
|
||||
CAFFE_ENFORCE(
|
||||
::google::protobuf::TextFormat::ParseFromString(spec, &net_def));
|
||||
|
||||
Workspace ws;
|
||||
net_def.set_num_workers(kTestPoolSize);
|
||||
@ -620,4 +626,98 @@ TEST(NetTest, OperatorWithExecutorHelper) {
|
||||
ASSERT_TRUE(net->Run());
|
||||
}
|
||||
|
||||
TEST(NetTest, OperatorWithDisabledEvent) {
|
||||
const auto spec = R"DOC(
|
||||
name: "example"
|
||||
type: "async_scheduling"
|
||||
external_input: "in"
|
||||
op {
|
||||
input: "in"
|
||||
output: "out"
|
||||
type: "NetTestDummy"
|
||||
arg {
|
||||
name: "fail"
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
)DOC";
|
||||
|
||||
Workspace ws;
|
||||
ws.CreateBlob("in");
|
||||
|
||||
NetDef net_def;
|
||||
CAFFE_ENFORCE(
|
||||
::google::protobuf::TextFormat::ParseFromString(spec, &net_def));
|
||||
|
||||
{
|
||||
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
|
||||
net->GetOperators()[0]->DisableEvent();
|
||||
// async_scheduling propagates exception
|
||||
bool caught_exception = false;
|
||||
try {
|
||||
net->Run();
|
||||
} catch (const std::exception& e) {
|
||||
caught_exception = true;
|
||||
}
|
||||
ASSERT_TRUE(caught_exception);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NetTest, ExecutorOverride) {
|
||||
const auto spec = R"DOC(
|
||||
name: "example"
|
||||
type: "dag"
|
||||
)DOC";
|
||||
|
||||
NetDef net_def;
|
||||
CAFFE_ENFORCE(
|
||||
::google::protobuf::TextFormat::ParseFromString(spec, &net_def));
|
||||
|
||||
{
|
||||
Workspace ws;
|
||||
auto old = FLAGS_caffe2_override_executor;
|
||||
auto g = MakeGuard([&]() { FLAGS_caffe2_override_executor = old; });
|
||||
FLAGS_caffe2_override_executor = "";
|
||||
|
||||
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
|
||||
auto dag_net = caffe2::dynamic_cast_if_rtti<DAGNet*>(net.get());
|
||||
ASSERT_TRUE(dag_net != nullptr);
|
||||
}
|
||||
|
||||
{
|
||||
Workspace ws;
|
||||
auto old = FLAGS_caffe2_override_executor;
|
||||
auto g = MakeGuard([&]() { FLAGS_caffe2_override_executor = old; });
|
||||
FLAGS_caffe2_override_executor = "dag,async_scheduling";
|
||||
|
||||
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
|
||||
auto async_net =
|
||||
caffe2::dynamic_cast_if_rtti<AsyncSchedulingNet*>(net.get());
|
||||
ASSERT_TRUE(async_net != nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NetTest, AsyncEmptyNet) {
|
||||
const auto spec = R"DOC(
|
||||
name: "example"
|
||||
type: "async_scheduling"
|
||||
)DOC";
|
||||
|
||||
Workspace ws;
|
||||
NetDef net_def;
|
||||
CAFFE_ENFORCE(
|
||||
::google::protobuf::TextFormat::ParseFromString(spec, &net_def));
|
||||
|
||||
{
|
||||
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
|
||||
bool caught_exception = false;
|
||||
try {
|
||||
ASSERT_TRUE(net->Run());
|
||||
} catch (const std::exception& e) {
|
||||
caught_exception = true;
|
||||
}
|
||||
ASSERT_FALSE(caught_exception);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
@ -30,8 +30,6 @@ class ObserverBase {
|
||||
|
||||
virtual std::unique_ptr<ObserverBase<T>> rnnCopy(T* subject, int rnn_order)
|
||||
const {
|
||||
LOG(WARNING)
|
||||
<< "rnnCopy() is not implemented and nullptr will be returned.";
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
|
@ -122,7 +122,9 @@ class OperatorBase : public Observable<OperatorBase> {
|
||||
}
|
||||
|
||||
inline void Wait(const OperatorBase& other, int stream_id = -1) {
|
||||
WaitEvent(other.event(), stream_id);
|
||||
if (!other.IsEventDisabled()) {
|
||||
WaitEvent(other.event(), stream_id);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void WaitEvents(
|
||||
@ -162,20 +164,20 @@ class OperatorBase : public Observable<OperatorBase> {
|
||||
if (HasAsyncPart()) {
|
||||
RecordEvent();
|
||||
} else {
|
||||
event().SetFinished();
|
||||
SetEventFinished();
|
||||
}
|
||||
} else {
|
||||
event().SetFinished(getErrorMsg().c_str());
|
||||
SetEventFinished(getErrorMsg().c_str());
|
||||
}
|
||||
return result;
|
||||
} catch (EnforceNotMet& err) {
|
||||
event().SetFinished(err.what());
|
||||
SetEventFinished(err.what());
|
||||
throw;
|
||||
} catch (const std::exception& err) {
|
||||
event().SetFinished(err.what());
|
||||
SetEventFinished(err.what());
|
||||
throw;
|
||||
} catch (...) {
|
||||
event().SetFinished(getErrorMsg().c_str());
|
||||
SetEventFinished(getErrorMsg().c_str());
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@ -316,6 +318,12 @@ class OperatorBase : public Observable<OperatorBase> {
|
||||
CAFFE_NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
void SetEventFinished(const char* err_msg = nullptr) {
|
||||
if (event_) {
|
||||
event_->SetFinished(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
std::string getErrorMsg() {
|
||||
if (has_debug_def()) {
|
||||
return "Error from operator: " + ProtoDebugString(debug_def());
|
||||
@ -438,10 +446,10 @@ class Operator : public OperatorBase {
|
||||
} else {
|
||||
// Manually set CPU operator's event status to finished,
|
||||
// unless this is an async CPU operator
|
||||
event().SetFinished();
|
||||
SetEventFinished();
|
||||
}
|
||||
} else {
|
||||
event().SetFinished(getErrorMsg().c_str());
|
||||
SetEventFinished(getErrorMsg().c_str());
|
||||
this->RecordLastFailedOpNetPosition();
|
||||
}
|
||||
return result;
|
||||
@ -451,15 +459,15 @@ class Operator : public OperatorBase {
|
||||
"Error from operator: \n" + ProtoDebugString(debug_def()));
|
||||
AddRelatedBlobInfo(&err);
|
||||
}
|
||||
event().SetFinished(err.what());
|
||||
SetEventFinished(err.what());
|
||||
this->RecordLastFailedOpNetPosition();
|
||||
throw;
|
||||
} catch (const std::exception& err) {
|
||||
event().SetFinished(err.what());
|
||||
SetEventFinished(err.what());
|
||||
this->RecordLastFailedOpNetPosition();
|
||||
throw;
|
||||
} catch (...) {
|
||||
event().SetFinished(getErrorMsg().c_str());
|
||||
SetEventFinished(getErrorMsg().c_str());
|
||||
this->RecordLastFailedOpNetPosition();
|
||||
throw;
|
||||
}
|
||||
|
@ -479,10 +479,15 @@ inline vector<TIndex> GetDimsVector(const TensorShape& shape) {
|
||||
inline std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
|
||||
InferOpInputOutputDevice(const OperatorDef& op) {
|
||||
auto op_schema = OpSchemaRegistry::Schema(op.type());
|
||||
CAFFE_ENFORCE(
|
||||
op_schema, "Device inference failed. No schema for: ", op.type());
|
||||
// TODO(wyiming) : add try catch here.
|
||||
return op_schema->InferDevice(op);
|
||||
if (op_schema) {
|
||||
// op_schema found
|
||||
return op_schema->InferDevice(op);
|
||||
|
||||
} else {
|
||||
// No schema for op.type registered
|
||||
auto temp_schema = OpSchema();
|
||||
return temp_schema.InferDevice(op);
|
||||
}
|
||||
}
|
||||
|
||||
template <uint64_t OpsPerPoint>
|
||||
|
@ -1,7 +1,4 @@
|
||||
#include "caffe2/core/predictor.h"
|
||||
#if CAFFE2_MOBILE
|
||||
#include "caffe2/core/init.h"
|
||||
#endif
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
@ -89,9 +86,6 @@ Predictor::Predictor(
|
||||
if (run_init) {
|
||||
CAFFE_ENFORCE(ws_.RunNetOnce(init_net));
|
||||
}
|
||||
#if CAFFE2_MOBILE
|
||||
GlobalInit();
|
||||
#endif
|
||||
|
||||
// real model inputs can be fed later in run* functions
|
||||
const auto& initialized_vec = ws_.Blobs();
|
||||
|
@ -19,6 +19,7 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> splitOpDevInfer(
|
||||
} // namespace.
|
||||
|
||||
REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
|
||||
REGISTER_CPU_OPERATOR(SplitByLengths, SplitByLengthsOp<CPUContext>);
|
||||
OPERATOR_SCHEMA(Split)
|
||||
.NumInputs(1, 2)
|
||||
.NumOutputs(1, INT_MAX)
|
||||
@ -36,6 +37,28 @@ to equal sized parts.
|
||||
)DOC")
|
||||
.InheritOnnxSchema("Split");
|
||||
|
||||
OPERATOR_SCHEMA(SplitByLengths)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1, INT_MAX)
|
||||
.Input(0, "input", "The tensor to split")
|
||||
.Input(1, "legnths", "The tensor `l_i` indicates the logic block of input.")
|
||||
.Arg("axis", "Which axis to split on")
|
||||
.Arg("order", "Either NHWC or NCWH, will split on C axis, defaults to NCHW")
|
||||
.DeviceInferenceFunction([](const OperatorDef& def) {
|
||||
auto op_device =
|
||||
def.has_device_option() ? def.device_option() : DeviceOption();
|
||||
vector<DeviceOption> in_dev(def.input_size(), op_device);
|
||||
vector<DeviceOption> out_dev(def.output_size(), op_device);
|
||||
// lengths input should be on CPU
|
||||
in_dev[1] = DeviceOption();
|
||||
return std::make_pair(in_dev, out_dev);
|
||||
})
|
||||
.SetDoc(R"DOC(
|
||||
Split a tensor into a list of tensors, given a lengths input, along the specified
|
||||
'axis'. If `K` outputs are provided, the op assumes `len(lengths) % K == 0`.
|
||||
The `input` will be split into `K` parts. Each part of length
|
||||
`sum(lengths[i*k:i*k+k))`)DOC");
|
||||
|
||||
namespace {
|
||||
OpSchema::Cost CostInferenceForConcat(
|
||||
const OperatorDef& def,
|
||||
@ -206,6 +229,7 @@ class GetSplitGradient : public GradientMakerBase {
|
||||
};
|
||||
REGISTER_GRADIENT(Split, GetSplitGradient);
|
||||
REGISTER_GRADIENT(DepthSplit, GetSplitGradient);
|
||||
REGISTER_GRADIENT(SplitByLengths, GetSplitGradient);
|
||||
|
||||
class GetConcatGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
|
@ -58,6 +58,35 @@ class SplitOp final : public Operator<Context> {
|
||||
// The split tensor is stored in CPU.
|
||||
};
|
||||
|
||||
template <class Context>
|
||||
class SplitByLengthsOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SplitByLengthsOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
CAFFE_ENFORCE(
|
||||
!(OperatorBase::HasArgument("axis") &&
|
||||
OperatorBase::HasArgument("order")),
|
||||
"You shouldn't specify both the dim to split, and the order "
|
||||
"in the case of 4-D images.");
|
||||
if (OperatorBase::HasArgument("axis")) {
|
||||
axis_ = OperatorBase::GetSingleArgument<int>("axis", 0);
|
||||
} else {
|
||||
axis_ = GetDimFromOrderString(
|
||||
OperatorBase::GetSingleArgument<string>("order", "NCHW"));
|
||||
}
|
||||
}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
||||
protected:
|
||||
int axis_;
|
||||
Tensor<Context> inclusive_scan_buffer_;
|
||||
Tensor<Context> inclusive_scan_length_buffer_;
|
||||
// Input: X, optionally split
|
||||
// The split tensor is stored in CPU.
|
||||
};
|
||||
|
||||
template <class Context>
|
||||
class ConcatOp final : public Operator<Context> {
|
||||
public:
|
||||
@ -166,6 +195,52 @@ bool SplitOp<Context>::RunOnDevice() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Implementations
|
||||
template <class Context>
|
||||
bool SplitByLengthsOp<Context>::RunOnDevice() {
|
||||
auto& input = Input(0);
|
||||
auto& length = OperatorBase::Input<TensorCPU>(1);
|
||||
auto length_length = length.size();
|
||||
CAFFE_ENFORCE_EQ(
|
||||
length_length % OutputSize(),
|
||||
0,
|
||||
"len(Lengths) should be divisible by OutputSize().");
|
||||
int canonical_axis = input.canonical_axis_index(axis_);
|
||||
CAFFE_ENFORCE_LT(
|
||||
canonical_axis, input.ndim(), "Axis not in input ndim range.");
|
||||
const int input_channels = input.dim32(canonical_axis);
|
||||
const auto* axis_data = length.template data<int>();
|
||||
CAFFE_ENFORCE_EQ(
|
||||
std::accumulate(axis_data, axis_data + length.size(), 0),
|
||||
input_channels,
|
||||
"Sum of split dimensions do not match: should be ",
|
||||
input_channels);
|
||||
vector<TIndex> output_dims(input.dims());
|
||||
int before = input.size_to_dim(canonical_axis);
|
||||
int after = input.size_from_dim(canonical_axis + 1);
|
||||
size_t input_offset = 0;
|
||||
for (int i = 0; i < OutputSize(); ++i) {
|
||||
auto* output = Output(i);
|
||||
const auto* axis_offset = axis_data + length_length / OutputSize() * i;
|
||||
auto axis_dim = std::accumulate(
|
||||
axis_offset, axis_offset + length_length / OutputSize(), 0);
|
||||
output_dims[canonical_axis] = axis_dim;
|
||||
output->Resize(output_dims);
|
||||
math::CopyMatrix<Context>(
|
||||
input.itemsize(),
|
||||
before,
|
||||
axis_dim * after,
|
||||
static_cast<const char*>(input.raw_data()) + input_offset,
|
||||
input.dim32(canonical_axis) * after,
|
||||
output->raw_mutable_data(input.meta()),
|
||||
axis_dim * after,
|
||||
&context_,
|
||||
input.meta().copy());
|
||||
input_offset += axis_dim * after * input.itemsize();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Context>
|
||||
bool ConcatOp<Context>::RunOnDevice() {
|
||||
auto* output = Output(0);
|
||||
|
@ -8,4 +8,6 @@ REGISTER_CUDA_OPERATOR(Concat, ConcatOp<CUDAContext>);
|
||||
// Backward compatibility settings
|
||||
REGISTER_CUDA_OPERATOR(DepthSplit, SplitOp<CUDAContext>);
|
||||
REGISTER_CUDA_OPERATOR(DepthConcat, ConcatOp<CUDAContext>);
|
||||
} // namespace caffe2
|
||||
|
||||
REGISTER_CUDA_OPERATOR(SplitByLengths, SplitByLengthsOp<CUDAContext>);
|
||||
} // namespace caffe2
|
||||
|
@ -150,10 +150,20 @@ __device__ float sigmoid_xent_backward_with_log_d_trick(float lgt, float tgt) {
|
||||
return (2 * tgt - 1.) / (1. + exp(lgt));
|
||||
}
|
||||
|
||||
__device__ float unjoined_sigmoid_xent_forward(float lgt, float tgt) {
|
||||
return lgt * tgt + (tgt - 1) * lgt * (lgt >= 0) -
|
||||
(1 - tgt) * log(1 + exp(lgt - 2 * lgt * (lgt >= 0)));
|
||||
}
|
||||
|
||||
__device__ float unjoined_sigmoid_xent_backward(float lgt, float tgt) {
|
||||
return tgt - (1. - tgt) / (1. + exp(-lgt));
|
||||
}
|
||||
|
||||
__global__ void SigmoidCrossEntropyWithLogitsKernel(
|
||||
const int outer_size,
|
||||
const int inner_size,
|
||||
const bool log_D_trick,
|
||||
const bool unjoined_lr_loss,
|
||||
const float* logits_ptr,
|
||||
const float* targets_ptr,
|
||||
float* out_ptr) {
|
||||
@ -162,11 +172,16 @@ __global__ void SigmoidCrossEntropyWithLogitsKernel(
|
||||
float value = 0;
|
||||
for (int in_idx = i * inner_size + threadIdx.x; in_idx < last_idx;
|
||||
in_idx += blockDim.x) {
|
||||
value +=
|
||||
(log_D_trick
|
||||
? sigmoid_xent_forward_with_log_d_trick(
|
||||
logits_ptr[in_idx], targets_ptr[in_idx])
|
||||
: sigmoid_xent_forward(logits_ptr[in_idx], targets_ptr[in_idx]));
|
||||
if (unjoined_lr_loss) {
|
||||
value += unjoined_sigmoid_xent_forward(
|
||||
logits_ptr[in_idx], targets_ptr[in_idx]);
|
||||
} else {
|
||||
value +=
|
||||
(log_D_trick
|
||||
? sigmoid_xent_forward_with_log_d_trick(
|
||||
logits_ptr[in_idx], targets_ptr[in_idx])
|
||||
: sigmoid_xent_forward(logits_ptr[in_idx], targets_ptr[in_idx]));
|
||||
}
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<float, CAFFE_CUDA_NUM_THREADS> BlockReduce;
|
||||
@ -181,6 +196,7 @@ __global__ void SigmoidCrossEntropyGradientWithLogitsKernel(
|
||||
const int outer_size,
|
||||
const int inner_size,
|
||||
const bool log_D_trick,
|
||||
const bool unjoined_lr_loss,
|
||||
const float* g_ptr,
|
||||
const float* logits_ptr,
|
||||
const float* targets_ptr,
|
||||
@ -188,11 +204,17 @@ __global__ void SigmoidCrossEntropyGradientWithLogitsKernel(
|
||||
CUDA_1D_KERNEL_LOOP(in_idx, outer_size * inner_size) {
|
||||
int i = in_idx / inner_size;
|
||||
auto g_factor = -g_ptr[i] / inner_size;
|
||||
out_ptr[in_idx] = g_factor *
|
||||
(log_D_trick
|
||||
? sigmoid_xent_backward_with_log_d_trick(
|
||||
logits_ptr[in_idx], targets_ptr[in_idx])
|
||||
: sigmoid_xent_backward(logits_ptr[in_idx], targets_ptr[in_idx]));
|
||||
if (unjoined_lr_loss) {
|
||||
out_ptr[in_idx] = g_factor *
|
||||
unjoined_sigmoid_xent_backward(
|
||||
logits_ptr[in_idx], targets_ptr[in_idx]);
|
||||
} else {
|
||||
out_ptr[in_idx] = g_factor *
|
||||
(log_D_trick ? sigmoid_xent_backward_with_log_d_trick(
|
||||
logits_ptr[in_idx], targets_ptr[in_idx])
|
||||
: sigmoid_xent_backward(
|
||||
logits_ptr[in_idx], targets_ptr[in_idx]));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
@ -227,7 +249,13 @@ bool SigmoidCrossEntropyWithLogitsOp<float, CUDAContext>::RunOnDevice() {
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
context_.cuda_stream()>>>(
|
||||
outer_size, inner_size, log_D_trick_, logits_ptr, targets_ptr, out_ptr);
|
||||
outer_size,
|
||||
inner_size,
|
||||
log_D_trick_,
|
||||
unjoined_lr_loss_,
|
||||
logits_ptr,
|
||||
targets_ptr,
|
||||
out_ptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -258,6 +286,7 @@ bool SigmoidCrossEntropyWithLogitsGradientOp<float, CUDAContext>::
|
||||
outer_size,
|
||||
inner_size,
|
||||
log_D_trick_,
|
||||
unjoined_lr_loss_,
|
||||
g_ptr,
|
||||
logits_ptr,
|
||||
targets_ptr,
|
||||
|
31
caffe2/operators/ensure_cpu_output_op.cc
Normal file
31
caffe2/operators/ensure_cpu_output_op.cc
Normal file
@ -0,0 +1,31 @@
|
||||
#include "caffe2/operators/ensure_cpu_output_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
// From CPU Context, the op takes CPU tensor as input, and produces
|
||||
// TensorCPU
|
||||
REGISTER_CPU_OPERATOR(EnsureCPUOutput, EnsureCPUOutputOp<CPUContext>);
|
||||
|
||||
OPERATOR_SCHEMA(EnsureCPUOutput)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.IdenticalTypeAndShape()
|
||||
.InputsCanCrossDevices()
|
||||
.DeviceInferenceFunction([](const OperatorDef& def) {
|
||||
auto op_device =
|
||||
def.has_device_option() ? def.device_option() : DeviceOption();
|
||||
auto cpu_option = DeviceOption();
|
||||
vector<DeviceOption> in_dev(def.input_size(), op_device);
|
||||
vector<DeviceOption> out_dev(def.output_size(), cpu_option);
|
||||
return std::make_pair(in_dev, out_dev);
|
||||
})
|
||||
.SetDoc(R"DOC(
|
||||
This Op always create TensorCPU output, and may involves cross-device MemCpy.
|
||||
Under CPU Context, this Op takes TensorCPU as input. Under the CUDA Context,
|
||||
this Op accepts either CUDA or CPU Tensor input.
|
||||
)DOC")
|
||||
.Input(0, "input", "The input CUDA or CPU tensor.")
|
||||
.Output(0, "output", "TensorCPU that is a copy of the input.");
|
||||
|
||||
NO_GRADIENT(EnsureCPUOutput);
|
||||
} // namespace caffe2
|
9
caffe2/operators/ensure_cpu_output_op.cu
Normal file
9
caffe2/operators/ensure_cpu_output_op.cu
Normal file
@ -0,0 +1,9 @@
|
||||
#include "caffe2/operators/ensure_cpu_output_op.h"
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
|
||||
namespace caffe2 {
|
||||
// From CUDA Context, takes either CUDA or CPU tensor as input, and produce
|
||||
// TensorCPU
|
||||
REGISTER_CUDA_OPERATOR(EnsureCPUOutput, EnsureCPUOutputOp<CUDAContext>);
|
||||
} // namespace caffe2
|
50
caffe2/operators/ensure_cpu_output_op.h
Normal file
50
caffe2/operators/ensure_cpu_output_op.h
Normal file
@ -0,0 +1,50 @@
|
||||
#ifndef CAFFE2_OPERATORS_ENSURE_CPU_OUTPUT_OP_H_
|
||||
#define CAFFE2_OPERATORS_ENSURE_CPU_OUTPUT_OP_H_
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <class Context>
|
||||
class EnsureCPUOutputOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
EnsureCPUOutputOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
if (OperatorBase::InputIsType<TensorCPU>(0)) {
|
||||
return CopyWithContext<CPUContext>();
|
||||
} else if (OperatorBase::InputIsType<Tensor<Context>>(0)) {
|
||||
// CUDA Context will go this branch
|
||||
return CopyWithContext<Context>();
|
||||
} else {
|
||||
CAFFE_THROW(
|
||||
"Unexpected Input Blob: ",
|
||||
OperatorBase::Inputs().at(0)->meta().name());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
template <class InputContext>
|
||||
bool CopyWithContext() {
|
||||
// Output is always on CPU
|
||||
auto* output = OperatorBase::Output<TensorCPU>(0);
|
||||
auto& input = OperatorBase::Input<Tensor<InputContext>>(0);
|
||||
output->ResizeLike(input);
|
||||
context_.template CopyItems<InputContext, CPUContext>(
|
||||
input.meta(),
|
||||
input.size(),
|
||||
input.raw_data(),
|
||||
output->raw_mutable_data(input.meta()));
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_ENSURE_CPU_OUTPUT_OP_H_
|
@ -8,6 +8,15 @@ REGISTER_CPU_OPERATOR(
|
||||
OPERATOR_SCHEMA(FloatToFused8BitRowwiseQuantized)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.TensorInferenceFunction([](const OperatorDef& /* def */,
|
||||
const vector<TensorShape>& in) {
|
||||
vector<TensorShape> out;
|
||||
TensorShape X = in[0];
|
||||
X.set_dims(1, X.dims(1) + 8);
|
||||
out.push_back(std::move(X));
|
||||
out[0].set_data_type(TensorProto_DataType_UINT8);
|
||||
return out;
|
||||
})
|
||||
.SetDoc(R"DOC(
|
||||
Applies 8-bit row-wise quantization by determining the range
|
||||
(maximum - minimum) and offset (minimum value) of each row in the input
|
||||
@ -28,6 +37,15 @@ REGISTER_CPU_OPERATOR(
|
||||
OPERATOR_SCHEMA(Fused8BitRowwiseQuantizedToFloat)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.TensorInferenceFunction([](const OperatorDef& /* def */,
|
||||
const vector<TensorShape>& in) {
|
||||
vector<TensorShape> out;
|
||||
TensorShape X = in[0];
|
||||
X.set_dims(1, X.dims(1) - 8);
|
||||
out.push_back(std::move(X));
|
||||
out[0].set_data_type(TensorProto_DataType_FLOAT);
|
||||
return out;
|
||||
})
|
||||
.SetDoc(R"DOC(
|
||||
De-quantizes the result of the
|
||||
FloatToFused8BitRowwiseQuantized operator. The input is expected to
|
||||
|
117
caffe2/operators/relu_n_op.cc
Normal file
117
caffe2/operators/relu_n_op.cc
Normal file
@ -0,0 +1,117 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/operators/relu_n_op.h"
|
||||
|
||||
#include "caffe2/utils/math.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <>
|
||||
bool ReluNOp<float, CPUContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
auto* Y = Output(0);
|
||||
Y->ResizeLike(X);
|
||||
|
||||
EigenVectorMap<float>(Y->mutable_data<float>(), X.size()) =
|
||||
ConstEigenVectorMap<float>(X.data<float>(), X.size())
|
||||
.cwiseMax(0.f)
|
||||
.cwiseMin(n);
|
||||
return true;
|
||||
}
|
||||
|
||||
// define a custom template unary functor
|
||||
template <typename Scalar>
|
||||
struct CwiseClampSignOp {
|
||||
CwiseClampSignOp(const Scalar& sup) : m_sup(sup) {}
|
||||
const Scalar operator()(const Scalar& x) const {
|
||||
return x < 0 ? 0 : (x >= m_sup ? 0 : 1);
|
||||
}
|
||||
Scalar m_sup;
|
||||
};
|
||||
|
||||
template <>
|
||||
bool ReluNGradientOp<float, CPUContext>::RunOnDevice() {
|
||||
auto& Y = Input(0);
|
||||
auto& dY = Input(1);
|
||||
auto* dX = Output(0);
|
||||
CAFFE_ENFORCE_EQ(dY.size(), Y.size());
|
||||
dX->ResizeLike(Y);
|
||||
|
||||
const float* Ydata = Y.data<float>();
|
||||
const float* dYdata = dY.data<float>();
|
||||
float* dXdata = dX->mutable_data<float>();
|
||||
// TODO: proper vectorization with Eigen
|
||||
EigenVectorArrayMap<float> dXvec(dXdata, dX->size());
|
||||
ConstEigenVectorArrayMap<float> Yvec(Ydata, Y.size());
|
||||
ConstEigenVectorArrayMap<float> dYvec(dYdata, dY.size());
|
||||
dXvec = dYvec * Yvec.unaryExpr(CwiseClampSignOp<float>(n));
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
OpSchema::Cost CostInferenceForReluN(
|
||||
const OperatorDef& def,
|
||||
const vector<TensorShape>& in) {
|
||||
struct OpSchema::Cost cost = PointwiseCostInference<2>(def, in);
|
||||
cost.params_bytes = 0;
|
||||
return cost;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
REGISTER_CPU_OPERATOR(ReluN, ReluNOp<float, CPUContext>);
|
||||
REGISTER_CPU_OPERATOR(ReluNGradient, ReluNGradientOp<float, CPUContext>);
|
||||
|
||||
// Input: X, output: Y
|
||||
OPERATOR_SCHEMA(ReluN)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.Arg("n", "the cap of output")
|
||||
.AllowInplace({{0, 0}})
|
||||
.CostInferenceFunction(CostInferenceForReluN)
|
||||
.IdenticalTypeAndShape()
|
||||
.SetDoc(R"DOC(
|
||||
Relu takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the rectified linear function, y = min(max(0, x), n),
|
||||
is applied to the tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "1D input tensor")
|
||||
.Output(0, "Y", "1D input tensor");
|
||||
|
||||
// Input: Y, dY, output: dX
|
||||
OPERATOR_SCHEMA(ReluNGradient)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.Arg("n", "the cap of forward op output")
|
||||
.AllowInplace({{1, 0}})
|
||||
.SetDoc(R"DOC(
|
||||
ReluGradient takes both Y and dY and uses this to update dX according to the
|
||||
chain rule and derivatives of the rectified linear function.
|
||||
)DOC");
|
||||
|
||||
class GetReluNGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
vector<OperatorDef> GetGradientDefs() override {
|
||||
return SingleGradientDef(
|
||||
def_.type() + "Gradient",
|
||||
"",
|
||||
vector<string>{O(0), GO(0)},
|
||||
vector<string>{GI(0)});
|
||||
}
|
||||
};
|
||||
REGISTER_GRADIENT(ReluN, GetReluNGradient);
|
||||
|
||||
} // namespace caffe2
|
82
caffe2/operators/relu_n_op.cu
Normal file
82
caffe2/operators/relu_n_op.cu
Normal file
@ -0,0 +1,82 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/operators/relu_n_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace {
|
||||
template <typename T>
|
||||
__global__ void ReluNKernel(const int N, const T* X, T* Y, const T thres) {
|
||||
CUDA_1D_KERNEL_LOOP(i, N) {
|
||||
auto data = X[i];
|
||||
Y[i] = data > 0 ? (data > thres ? thres : data) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ReluNGradientKernel(
|
||||
const int N,
|
||||
const T* Y,
|
||||
const T* dY,
|
||||
T* dX,
|
||||
const T thres) {
|
||||
CUDA_1D_KERNEL_LOOP(i, N) {
|
||||
auto data = Y[i];
|
||||
dX[i] = data > 0 ? (data >= thres ? 0 : dY[i]) : 0;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <>
|
||||
bool ReluNOp<float, CUDAContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
auto* Y = Output(0);
|
||||
CAFFE_ENFORCE_GT(X.size(), 0);
|
||||
Y->ResizeLike(X);
|
||||
ReluNKernel<<<
|
||||
CAFFE_GET_BLOCKS(X.size()),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
context_.cuda_stream()>>>(
|
||||
X.size(), X.data<float>(), Y->mutable_data<float>(), n);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool ReluNGradientOp<float, CUDAContext>::RunOnDevice() {
|
||||
auto& Y = Input(0);
|
||||
auto& dY = Input(1);
|
||||
auto* dX = Output(0);
|
||||
CAFFE_ENFORCE_GT(Y.size(), 0);
|
||||
CAFFE_ENFORCE_EQ(dY.size(), Y.size());
|
||||
dX->ResizeLike(Y);
|
||||
ReluNGradientKernel<float>
|
||||
<<<CAFFE_GET_BLOCKS(Y.size()),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
context_.cuda_stream()>>>(
|
||||
Y.size(),
|
||||
Y.data<float>(),
|
||||
dY.data<float>(),
|
||||
dX->mutable_data<float>(),
|
||||
n);
|
||||
return true;
|
||||
}
|
||||
|
||||
REGISTER_CUDA_OPERATOR(ReluN, ReluNOp<float, CUDAContext>);
|
||||
REGISTER_CUDA_OPERATOR(ReluNGradient, ReluNGradientOp<float, CUDAContext>);
|
||||
} // namespace caffe2
|
62
caffe2/operators/relu_n_op.h
Normal file
62
caffe2/operators/relu_n_op.h
Normal file
@ -0,0 +1,62 @@
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CAFFE2_OPERATORS_RELU_N_OP_H_
|
||||
#define CAFFE2_OPERATORS_RELU_N_OP_H_
|
||||
|
||||
#include "caffe2/core/common_omp.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <typename T, class Context>
|
||||
class ReluNOp final : public Operator<Context> {
|
||||
public:
|
||||
ReluNOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
n(OperatorBase::GetSingleArgument<float>("n", 6.0)) {
|
||||
CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0");
|
||||
}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
||||
protected:
|
||||
float n;
|
||||
};
|
||||
|
||||
template <typename T, class Context>
|
||||
class ReluNGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
ReluNGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
n(OperatorBase::GetSingleArgument<float>("n", 6.0)) {
|
||||
CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0");
|
||||
}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
||||
protected:
|
||||
// Input: Y, dY; Output: dX
|
||||
float n;
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_RELU_N_OP_H_
|
@ -7,14 +7,28 @@ REGISTER_CPU_OPERATOR(Shape, ShapeOp<CPUContext>);
|
||||
OPERATOR_SCHEMA(Shape)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.TensorInferenceFunction([](const OperatorDef& /*def*/,
|
||||
.Arg(
|
||||
"axes",
|
||||
"(int[]) array of interested axes."
|
||||
"If given, this operators only returns the dimension of given axes."
|
||||
"Otherwise, the operator returns full dimension.")
|
||||
.TensorInferenceFunction([](const OperatorDef& def,
|
||||
const vector<TensorShape>& in) {
|
||||
ArgumentHelper args(def);
|
||||
const vector<int>& axes = args.GetRepeatedArgument<int>("axes");
|
||||
vector<TensorShape> out(1);
|
||||
out[0].add_dims(in[0].dims().size());
|
||||
if (axes.empty()) {
|
||||
out[0].add_dims(in[0].dims().size());
|
||||
} else {
|
||||
out[0].add_dims(axes.size());
|
||||
}
|
||||
out[0].set_data_type(TensorProto::INT32);
|
||||
return out;
|
||||
})
|
||||
.SetDoc("Produce a 1D int64 tensor with the shape of the input tensor.");
|
||||
.SetDoc(R"DOC(
|
||||
Produce a 1D int64 tensor with the shape of the input tensor.
|
||||
If called with an optional argument \"axes\", the result will only
|
||||
contain the dimension of specified axes in particular order.)DOC");
|
||||
|
||||
SHOULD_NOT_DO_GRADIENT(Shape);
|
||||
|
||||
|
@ -13,17 +13,41 @@ template <class Context>
|
||||
class ShapeOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
USE_SIMPLE_CTOR_DTOR(ShapeOp);
|
||||
ShapeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
axes_(OperatorBase ::GetRepeatedArgument<int>("axes")) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
auto& input = Input(0);
|
||||
auto& data = Input(DATA);
|
||||
auto* output = OperatorBase::Output<Tensor<Context>>(0);
|
||||
output->Resize(input.ndim());
|
||||
TIndex* output_data = output->template mutable_data<TIndex>();
|
||||
context_.template CopyBytes<Context, Context>(
|
||||
input.ndim() * sizeof(TIndex), input.dims().data(), output_data);
|
||||
int numDims = data.ndim();
|
||||
int numAxes = axes_.size();
|
||||
if (numAxes == 0) {
|
||||
output->Resize(numDims);
|
||||
TIndex* output_data = output->template mutable_data<TIndex>();
|
||||
context_.template CopyBytes<Context, Context>(
|
||||
numDims * sizeof(TIndex), data.dims().data(), output_data);
|
||||
return true;
|
||||
}
|
||||
|
||||
output->Resize(numAxes);
|
||||
auto src = reinterpret_cast<const char*>(data.dims().data());
|
||||
auto out = reinterpret_cast<char*>(output->template mutable_data<TIndex>());
|
||||
for (int i = 0; i < numAxes; i++) {
|
||||
auto axis = axes_[i];
|
||||
CAFFE_ENFORCE_LT(axis, numDims, "Axis out of range");
|
||||
CAFFE_ENFORCE_GE(axis, 0, "Each axis should be non-negative");
|
||||
context_.template CopyBytes<Context, Context>(
|
||||
sizeof(TIndex), src + axis * sizeof(TIndex), out);
|
||||
out += sizeof(TIndex);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
INPUT_TAGS(DATA);
|
||||
|
||||
private:
|
||||
vector<int> axes_;
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
@ -26,10 +26,7 @@ REGISTER_CPU_OPERATOR(
|
||||
ScatterWeightedSum,
|
||||
ScatterWeightedSumOp<float, CPUContext>);
|
||||
REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
|
||||
// From whatever the current context, ensure the output is TensorCPU
|
||||
REGISTER_CPU_OPERATOR(
|
||||
EnsureCPUOutput,
|
||||
CopyOp<CPUContext, CPUContext, CPUContext>);
|
||||
|
||||
// From CPU, copy it to whatever the current context
|
||||
REGISTER_CPU_OPERATOR(
|
||||
CopyFromCPUInput,
|
||||
@ -300,26 +297,6 @@ Copy tensor for CPU to GPU context. Must be run under GPU device option.
|
||||
.Input(0, "input", "The input tensor.")
|
||||
.Output(0, "output", "Tensor that will contain a copy of the input.");
|
||||
|
||||
OPERATOR_SCHEMA(EnsureCPUOutput)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.IdenticalTypeAndShape()
|
||||
.InputsCanCrossDevices()
|
||||
.DeviceInferenceFunction([](const OperatorDef& def) {
|
||||
auto op_device =
|
||||
def.has_device_option() ? def.device_option() : DeviceOption();
|
||||
auto cpu_option = DeviceOption();
|
||||
vector<DeviceOption> in_dev(def.input_size(), op_device);
|
||||
vector<DeviceOption> out_dev(def.output_size(), cpu_option);
|
||||
return std::make_pair(in_dev, out_dev);
|
||||
})
|
||||
.SetDoc(R"DOC(
|
||||
Take an input tensor in the current Context (GPU or CPU) and create an output
|
||||
which is always a TensorCPU. This may involves cross-device MemCpy.
|
||||
)DOC")
|
||||
.Input(0, "input", "The input CUDA or CPU tensor.")
|
||||
.Output(0, "output", "TensorCPU that is a copy of the input.");
|
||||
|
||||
OPERATOR_SCHEMA(CopyFromCPUInput)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
|
@ -58,10 +58,7 @@ REGISTER_CUDA_OPERATOR(Alias, AliasOp<CUDAContext>);
|
||||
REGISTER_CUDA_OPERATOR(ResizeLike, ResizeLikeOp<CUDAContext>);
|
||||
REGISTER_CUDA_OPERATOR(Sum, SumOp<CUDAContext>);
|
||||
REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<CUDAContext>);
|
||||
// From whatever the current context, ensure the output is TensorCPU
|
||||
REGISTER_CUDA_OPERATOR(
|
||||
EnsureCPUOutput,
|
||||
CopyOp<CUDAContext, CPUContext, CUDAContext>);
|
||||
|
||||
// From CPU, copy it to whatever the current context
|
||||
REGISTER_CUDA_OPERATOR(
|
||||
CopyFromCPUInput,
|
||||
|
@ -26,37 +26,6 @@ static void AddConstInput(
|
||||
return;
|
||||
}
|
||||
|
||||
TEST(UtilityOpGPUTest, testEnsureCPUOutput) {
|
||||
if (!HasCudaGPU())
|
||||
return;
|
||||
Workspace ws;
|
||||
OperatorDef def;
|
||||
def.set_name("test");
|
||||
def.set_type("EnsureCPUOutput");
|
||||
def.add_input("X");
|
||||
def.add_output("Y");
|
||||
def.mutable_device_option()->set_device_type(CUDA);
|
||||
AddConstInput(vector<TIndex>{5, 10}, 3.14, "X", &ws);
|
||||
Blob* Xblob = ws.GetBlob("X");
|
||||
EXPECT_NE(nullptr, Xblob);
|
||||
// input X should start as a CUDATensor
|
||||
EXPECT_TRUE(Xblob->IsType<Tensor<CUDAContext>>());
|
||||
// now execute the op to get Y
|
||||
unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
|
||||
EXPECT_NE(nullptr, op.get());
|
||||
EXPECT_TRUE(op->Run());
|
||||
Blob* Yblob = ws.GetBlob("Y");
|
||||
EXPECT_NE(nullptr, Yblob);
|
||||
// output Y should be a CPUTensor
|
||||
EXPECT_TRUE(Yblob->IsType<Tensor<CPUContext>>());
|
||||
const TensorCPU& Y_cpu = Yblob->Get<Tensor<CPUContext>>();
|
||||
EXPECT_EQ(Y_cpu.size(), 5 * 10);
|
||||
for (int i = 0; i < Y_cpu.size(); ++i) {
|
||||
EXPECT_LT(Y_cpu.data<float>()[i], 3.15);
|
||||
EXPECT_GT(Y_cpu.data<float>()[i], 3.13);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(UtilityOpGPUTest, testReshapeWithScalar) {
|
||||
if (!HasCudaGPU())
|
||||
return;
|
||||
|
@ -23,34 +23,6 @@ static void AddConstInput(
|
||||
return;
|
||||
}
|
||||
|
||||
TEST(UtilityOpTest, testEnsureCPUOutput) {
|
||||
Workspace ws;
|
||||
OperatorDef def;
|
||||
def.set_name("test");
|
||||
def.set_type("EnsureCPUOutput");
|
||||
def.add_input("X");
|
||||
def.add_output("Y");
|
||||
AddConstInput(vector<TIndex>{5, 10}, 3.14, "X", &ws);
|
||||
Blob* Xblob = ws.GetBlob("X");
|
||||
EXPECT_NE(nullptr, Xblob);
|
||||
// input X should be a CPUTensor
|
||||
EXPECT_TRUE(Xblob->IsType<Tensor<CPUContext>>());
|
||||
// now execute the op to get Y
|
||||
unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
|
||||
EXPECT_NE(nullptr, op.get());
|
||||
EXPECT_TRUE(op->Run());
|
||||
Blob* Yblob = ws.GetBlob("Y");
|
||||
EXPECT_NE(nullptr, Yblob);
|
||||
// output Y should be a CPUTensor
|
||||
EXPECT_TRUE(Yblob->IsType<Tensor<CPUContext>>());
|
||||
const TensorCPU& Y_cpu = Yblob->Get<Tensor<CPUContext>>();
|
||||
EXPECT_EQ(Y_cpu.size(), 5 * 10);
|
||||
for (int i = 0; i < Y_cpu.size(); ++i) {
|
||||
EXPECT_LT(Y_cpu.data<float>()[i], 3.15);
|
||||
EXPECT_GT(Y_cpu.data<float>()[i], 3.13);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(UtilityOpTest, testReshapeWithScalar) {
|
||||
Workspace ws;
|
||||
OperatorDef def;
|
||||
|
@ -8,89 +8,118 @@ from __future__ import unicode_literals
|
||||
import os
|
||||
|
||||
from caffe2.python import core
|
||||
from caffe2.python.dataio import Reader
|
||||
from caffe2.python.dataset import Dataset
|
||||
from caffe2.python.db_file_reader import DBFileReader
|
||||
from caffe2.python.pipeline import pipe
|
||||
from caffe2.python.task import Cluster, TaskGroup
|
||||
|
||||
|
||||
class CachedReader(Reader):
|
||||
"""
|
||||
Reader with persistent in-file cache.
|
||||
class CachedReader(DBFileReader):
|
||||
|
||||
default_name_suffix = 'cached_reader'
|
||||
|
||||
"""Reader with persistent in-file cache.
|
||||
|
||||
Example usage:
|
||||
cached_reader = CachedReader(reader)
|
||||
build_cache_step = cached_reader.build_cache('/tmp/cache.db')
|
||||
cached_reader = CachedReader(
|
||||
reader,
|
||||
db_path='/tmp/cache.db',
|
||||
db_type='LevelDB',
|
||||
)
|
||||
build_cache_step = cached_reader.build_cache_step()
|
||||
with LocalSession() as session:
|
||||
session.run(build_cache_step)
|
||||
|
||||
Every time new reader is created, it's expected that build_cache will be
|
||||
called before setup_ex and usage of the reader. build_cache will check
|
||||
existence of provided file path and in case it's missing will initialize it
|
||||
by reading data from original reader. All consequent attempts to read will
|
||||
ignore original reader (i.e. no additional data will be read from it).
|
||||
Every time new CachedReader is created, it's expected that
|
||||
db_path exists before calling .setup_ex(...) and .read(...).
|
||||
|
||||
If db_path doesn't exist, it's expected build_cache_step to be called
|
||||
first to build a cache at db_path.
|
||||
|
||||
build_cache_step will check existence of provided db_path and in case
|
||||
it's missing will initialize it by reading data from original reader.
|
||||
All consequent attempts to read will ignore original reader
|
||||
(i.e. no additional data will be read from it).
|
||||
|
||||
Args:
|
||||
original_reader: Reader.
|
||||
If provided, it's the original reader used to build the cache file.
|
||||
db_path: str.
|
||||
db_type: str. DB type of file. A db_type is registed by
|
||||
`REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
|
||||
Default to 'LevelDB'.
|
||||
name: str or None. Name of CachedReader.
|
||||
Optional name to prepend to blobs that will store the data.
|
||||
Default to '<db_name>_<default_name_suffix>'.
|
||||
batch_size: int.
|
||||
How many examples are read for each time the read_net is run.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
original_reader,
|
||||
db_path,
|
||||
db_type='LevelDB',
|
||||
name=None,
|
||||
batch_size=100,
|
||||
):
|
||||
assert original_reader is not None, "original_reader can't be None"
|
||||
self.original_reader = original_reader
|
||||
|
||||
def __init__(self, reader, db_type='leveldb', name='cached_reader'):
|
||||
super(CachedReader, self).__init__(reader.schema())
|
||||
self.original_reader = reader
|
||||
self.cache_path = None
|
||||
self.ds_reader = None
|
||||
self.ds = Dataset(self._schema, name)
|
||||
self.db_type = db_type
|
||||
self.name = name
|
||||
self.field_names = self._schema.field_names()
|
||||
super(CachedReader, self).__init__(
|
||||
db_path,
|
||||
db_type,
|
||||
name,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
def setup_ex(self, init_net, finish_net):
|
||||
assert self.cache_path, 'build_cache must be called first'
|
||||
self._init_dataset(init_net)
|
||||
self._load_from_file(init_net)
|
||||
self.ds_reader = self.ds.reader(init_net, batch_size=100)
|
||||
def _init_reader_schema(self):
|
||||
"""Prepare the reader schema.
|
||||
|
||||
def read(self, read_net):
|
||||
assert self.ds_reader, 'setup must be called first'
|
||||
return self.ds_reader.read(read_net)
|
||||
Since an original reader is given,
|
||||
use it's schema as ground truth.
|
||||
|
||||
def has_cache(self):
|
||||
return self.cache_path and os.path.exists(self.cache_path)
|
||||
Returns:
|
||||
schema: schema.Struct. Used in Reader.__init__(...).
|
||||
"""
|
||||
return self.original_reader._schema
|
||||
|
||||
def build_cache(self, cache_path, overwrite=False):
|
||||
if not self.has_cache() or overwrite:
|
||||
self.cache_path = cache_path
|
||||
if self.has_cache() and not overwrite:
|
||||
def build_cache_step(self, overwrite=False):
|
||||
"""Build a step for generating cache DB file.
|
||||
|
||||
If self.db_path exists and not overwritting, build an empty step.
|
||||
Overwise, build a step as follows.
|
||||
Pipe original reader to the _DatasetWriter,
|
||||
so that dataset field blobs are populated.
|
||||
Then save these blobs into a file.
|
||||
|
||||
Args:
|
||||
overwrite: bool. If true, ignore the existing file
|
||||
and build a new one overwritting the existing one anyway.
|
||||
|
||||
Returns:
|
||||
build_cache_step: ExcutionStep.
|
||||
The step to be run for building a cache DB file.
|
||||
"""
|
||||
if os.path.exists(self.db_path) and not overwrite:
|
||||
# cache already exists, no need to rebuild it
|
||||
return core.execution_step('build_step', [])
|
||||
|
||||
init_net = core.Net('init')
|
||||
self._init_dataset(init_net)
|
||||
self._init_field_blobs_as_empty(init_net)
|
||||
with Cluster(), core.NameScope(self.name), TaskGroup() as copy_tg:
|
||||
pipe(self.original_reader, self.ds.writer(), num_threads=16)
|
||||
copy_step = copy_tg.to_task().get_step()
|
||||
save_net = core.Net('save')
|
||||
self._save_to_file(save_net)
|
||||
self._save_field_blobs_to_db_file(save_net)
|
||||
|
||||
return core.execution_step('build_cache', [init_net, copy_step, save_net])
|
||||
|
||||
def _init_dataset(self, init_net):
|
||||
with core.NameScope(self.name):
|
||||
self.ds.init_empty(init_net)
|
||||
|
||||
def _save_to_file(self, net):
|
||||
def _save_field_blobs_to_db_file(self, net):
|
||||
"""Save dataset field blobs to a DB file at db_path"""
|
||||
net.Save(
|
||||
self.ds.content().field_blobs(),
|
||||
self.ds.get_blobs(),
|
||||
[],
|
||||
db=self.cache_path,
|
||||
db=self.db_path,
|
||||
db_type=self.db_type,
|
||||
blob_name_overrides=self.field_names,
|
||||
blob_name_overrides=self.ds.field_names(),
|
||||
absolute_path=True,
|
||||
)
|
||||
|
||||
def _load_from_file(self, net):
|
||||
net.Load(
|
||||
[],
|
||||
self.ds.content().field_blobs(),
|
||||
db=self.cache_path,
|
||||
db_type=self.db_type,
|
||||
absolute_path=True,
|
||||
source_blob_names=self.field_names,
|
||||
)
|
||||
|
@ -122,6 +122,19 @@ def InferBlobDevices(net):
|
||||
return mapping
|
||||
|
||||
|
||||
def InferOpBlobDevicesAsDict(op):
|
||||
input_dev_list, output_dev_list = InferOpBlobDevices(op)
|
||||
input_dict = {
|
||||
op.input[i]: input_dev_list[i]
|
||||
for i in range(len(op.input))
|
||||
}
|
||||
output_dict = {
|
||||
op.output[i]: output_dev_list[i]
|
||||
for i in range(len(op.output))
|
||||
}
|
||||
return input_dict, output_dict
|
||||
|
||||
|
||||
def InferOpBlobDevices(op):
|
||||
device_info = C.infer_op_input_output_device(op.SerializeToString())
|
||||
input_info = []
|
||||
|
@ -613,10 +613,24 @@ class TestInferDevice(test_util.TestCase):
|
||||
with core.DeviceScope(op_option):
|
||||
op = core.CreateOperator(op_name, inputs, outputs)
|
||||
input_dev, output_dev = core.InferOpBlobDevices(op)
|
||||
for in_dev in input_dev:
|
||||
self.assertEqual(in_dev, in_option)
|
||||
for out_dev in output_dev:
|
||||
self.assertEqual(out_dev, out_option)
|
||||
if isinstance(in_option, list):
|
||||
assert len(in_option) == len(input_dev), \
|
||||
'Length of input device option should match' \
|
||||
'{} vs. {}'.format(in_option, input_dev)
|
||||
for in_dev, in_opt in zip(input_dev, in_option):
|
||||
self.assertEqual(in_dev, in_opt)
|
||||
else:
|
||||
for in_dev in input_dev:
|
||||
self.assertEqual(in_dev, in_option)
|
||||
if isinstance(out_option, list):
|
||||
assert len(out_option) == len(output_dev), \
|
||||
'Length of output device option should match' \
|
||||
'{} vs. {}'.format(out_option, output_dev)
|
||||
for out_dev, out_opt in zip(output_dev, out_option):
|
||||
self.assertEqual(out_dev, out_opt)
|
||||
else:
|
||||
for out_dev in output_dev:
|
||||
self.assertEqual(out_dev, out_option)
|
||||
|
||||
def test_infer_device(self):
|
||||
self._test_op(
|
||||
@ -628,6 +642,16 @@ class TestInferDevice(test_util.TestCase):
|
||||
outputs=["fc_1"]
|
||||
)
|
||||
|
||||
def test_infer_device_split_by_lengths(self):
|
||||
self._test_op(
|
||||
"SplitByLengths",
|
||||
[self.cuda_option, self.cpu_option],
|
||||
self.cuda_option,
|
||||
op_option=self.cuda_option,
|
||||
inputs=["data", "fc_w"],
|
||||
outputs=["fc_1"]
|
||||
)
|
||||
|
||||
def test_infer_device_cross_device(self):
|
||||
self._test_op("CopyGPUToCPU", self.cuda_option, self.cpu_option)
|
||||
self._test_op("CopyCPUToGPU", self.cpu_option, self.cuda_option)
|
||||
|
@ -14,6 +14,7 @@ from caffe2.python import \
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
import numpy as np
|
||||
import warnings
|
||||
|
||||
dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/nccl:nccl_ops")
|
||||
dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/gloo:gloo_ops")
|
||||
@ -717,23 +718,23 @@ def _AddBarrierToModelNets(model, barrier_net_timeout_sec):
|
||||
# shards that are faster than others will begin training the next epoch
|
||||
# while stragglers are blocked on IO, and may timeout after 30 seconds
|
||||
# (_DEFAULT_TIMEOUT_SEC).
|
||||
model._barrier_net = _CreateBarrierNet(model,
|
||||
# We pass in model.param_init_net so that the barrier net can be run as
|
||||
# part of the param_init_net.
|
||||
model._barrier_net = _CreateBarrierNet(model, model.param_init_net,
|
||||
"pre_training", barrier_net_timeout_sec)
|
||||
model._data_parallel_model_nets.insert(0, model._barrier_net)
|
||||
|
||||
|
||||
def _CreateBarrierNet(model, name_prefix, timeout_sec):
|
||||
def _CreateBarrierNet(model, init_net, name_prefix, timeout_sec):
|
||||
log.info("Creating barrier net")
|
||||
assert model._rendezvous['engine'] == 'GLOO', "Engine does not support barrier"
|
||||
barrier_init_net = core.Net(name_prefix + "_barrier_init_net")
|
||||
comm_world = _CreateOrCloneCommonWorld(
|
||||
barrier_init_net,
|
||||
init_net,
|
||||
name_prefix + "_barrier_cw",
|
||||
rendezvous=model._rendezvous,
|
||||
status_blob=name_prefix + "_barrier_cw_status",
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
workspace.RunNetOnce(barrier_init_net)
|
||||
barrier_net = core.Net(name_prefix + "_barrier_net")
|
||||
barrier_net.Barrier(
|
||||
inputs=[comm_world],
|
||||
@ -744,13 +745,23 @@ def _CreateBarrierNet(model, name_prefix, timeout_sec):
|
||||
return barrier_net
|
||||
|
||||
|
||||
# DEPRECATED: See warnings below.
|
||||
def Synchronize(model, timeout_sec=_DEFAULT_BARRIER_NET_TIMEOUT_SEC):
|
||||
warnings.warn("The Synchronize API has been deprecated. We now have a "
|
||||
"barrier net which runs before training to ensure all hosts wait "
|
||||
"before training starts. The default timeout for the barrier is "
|
||||
"300s and it can be overridden using the barrier_net_timeout_sec "
|
||||
"parameter when calling Parallelize.",
|
||||
category=DeprecationWarning, stacklevel=2)
|
||||
if model._rendezvous is None or model._rendezvous['num_shards'] <= 1:
|
||||
# Single host case
|
||||
return
|
||||
|
||||
if model._sync_barrier_net is None:
|
||||
model._sync_barrier_net = _CreateBarrierNet(model, "sync", timeout_sec)
|
||||
barrier_init_net = core.Net("sync_barrier_init_net")
|
||||
model._sync_barrier_net = _CreateBarrierNet(
|
||||
model, barrier_init_net, "sync", timeout_sec)
|
||||
workspace.RunNetOnce(barrier_init_net)
|
||||
workspace.CreateNet(model._sync_barrier_net)
|
||||
model._sync_barrier_net_timeout = timeout_sec
|
||||
assert model._sync_barrier_net_timeout == timeout_sec, \
|
||||
|
@ -384,7 +384,7 @@ class DataParallelModelTest(TestCase):
|
||||
# Set network timeout to 2 seconds, and add a 3 seconds
|
||||
# sleep for 1 host. Make sure there is no timeout on the
|
||||
# second RunNet.
|
||||
data_parallel_model._DEFAULT_TIMEOUT_SEC=2
|
||||
data_parallel_model._DEFAULT_TIMEOUT_SEC = 2
|
||||
data_parallel_model.Parallelize_CPU(
|
||||
model,
|
||||
input_builder_fun=add_input_ops,
|
||||
|
@ -12,6 +12,7 @@ from caffe2.python.dataio import (
|
||||
ReaderWithTimeLimit,
|
||||
)
|
||||
from caffe2.python.dataset import Dataset
|
||||
from caffe2.python.db_file_reader import DBFileReader
|
||||
from caffe2.python.pipeline import pipe
|
||||
from caffe2.python.schema import Struct, NewRecord, FeedRecord
|
||||
from caffe2.python.session import LocalSession
|
||||
@ -52,16 +53,6 @@ def make_destination_dataset(ws, schema, name=None):
|
||||
return dst_ds
|
||||
|
||||
|
||||
def read_all_data(ws, reader, session):
|
||||
dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())
|
||||
|
||||
with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
|
||||
pipe(reader, dst_ds.writer(), num_runtime_threads=8)
|
||||
session.run(tg)
|
||||
|
||||
return ws.blobs[str(dst_ds.content().label())].fetch()
|
||||
|
||||
|
||||
class ReaderWithDelay(Reader):
|
||||
"""Test reader class that inserts a delay between reading batches."""
|
||||
def __init__(self, reader, delay):
|
||||
@ -345,46 +336,107 @@ class TestReaderWithLimit(TestCase):
|
||||
read_delay=0.25,
|
||||
duration=6)
|
||||
|
||||
|
||||
class TestDBFileReader(TestCase):
|
||||
def setUp(self):
|
||||
self.temp_paths = []
|
||||
|
||||
def tearDown(self):
|
||||
# In case any test method fails, clean up temp paths.
|
||||
for path in self.temp_paths:
|
||||
self._delete_path(path)
|
||||
|
||||
@staticmethod
|
||||
def _delete_path(path):
|
||||
if os.path.isfile(path):
|
||||
os.remove(path) # Remove file.
|
||||
elif os.path.isdir(path):
|
||||
shutil.rmtree(path) # Remove dir recursively.
|
||||
|
||||
def _make_temp_path(self):
|
||||
# Make a temp path as db_path.
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
temp_path = f.name
|
||||
self.temp_paths.append(temp_path)
|
||||
return temp_path
|
||||
|
||||
@staticmethod
|
||||
def _build_source_reader(ws, size):
|
||||
src_ds = make_source_dataset(ws, size)
|
||||
return src_ds.reader()
|
||||
|
||||
@staticmethod
|
||||
def _read_all_data(ws, reader, session):
|
||||
dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())
|
||||
|
||||
with TaskGroup() as tg:
|
||||
pipe(reader, dst_ds.writer(), num_runtime_threads=8)
|
||||
session.run(tg)
|
||||
|
||||
return ws.blobs[str(dst_ds.content().label())].fetch()
|
||||
|
||||
def test_cached_reader(self):
|
||||
ws = workspace.C.Workspace()
|
||||
session = LocalSession(ws)
|
||||
|
||||
def build_source_reader(size):
|
||||
src_ds = make_source_dataset(ws, size)
|
||||
return src_ds.reader()
|
||||
|
||||
# Make a temp file path as cache_path
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
cache_path = f.name
|
||||
f.close()
|
||||
os.remove(cache_path)
|
||||
db_path = self._make_temp_path()
|
||||
|
||||
# Read data for the first time.
|
||||
cached_reader1 = CachedReader(build_source_reader(100))
|
||||
init_step = cached_reader1.build_cache(cache_path)
|
||||
session.run(init_step)
|
||||
cached_reader1 = CachedReader(
|
||||
self._build_source_reader(ws, 100), db_path,
|
||||
)
|
||||
build_cache_step = cached_reader1.build_cache_step()
|
||||
session.run(build_cache_step)
|
||||
|
||||
data = read_all_data(ws, cached_reader1, session)
|
||||
data = self._read_all_data(ws, cached_reader1, session)
|
||||
self.assertEqual(sorted(data), list(range(100)))
|
||||
|
||||
# Read data from cache.
|
||||
workspace.ResetWorkspace()
|
||||
cached_reader2 = CachedReader(build_source_reader(200))
|
||||
init_step = cached_reader2.build_cache(cache_path)
|
||||
session.run(init_step)
|
||||
cached_reader2 = CachedReader(
|
||||
self._build_source_reader(ws, 200), db_path,
|
||||
)
|
||||
build_cache_step = cached_reader2.build_cache_step()
|
||||
session.run(build_cache_step)
|
||||
|
||||
data = read_all_data(ws, cached_reader2, session)
|
||||
data = self._read_all_data(ws, cached_reader2, session)
|
||||
self.assertEqual(sorted(data), list(range(100)))
|
||||
|
||||
shutil.rmtree(cache_path)
|
||||
self._delete_path(db_path)
|
||||
|
||||
# We removed cache so we expect to receive data from original reader
|
||||
# We removed cache so we expect to receive data from original reader.
|
||||
workspace.ResetWorkspace()
|
||||
cached_reader3 = CachedReader(build_source_reader(300))
|
||||
init_step = cached_reader3.build_cache(cache_path)
|
||||
session.run(init_step)
|
||||
cached_reader3 = CachedReader(
|
||||
self._build_source_reader(ws, 300), db_path,
|
||||
)
|
||||
build_cache_step = cached_reader3.build_cache_step()
|
||||
session.run(build_cache_step)
|
||||
|
||||
data = read_all_data(ws, cached_reader3, session)
|
||||
data = self._read_all_data(ws, cached_reader3, session)
|
||||
self.assertEqual(sorted(data), list(range(300)))
|
||||
|
||||
shutil.rmtree(cache_path)
|
||||
self._delete_path(db_path)
|
||||
|
||||
def test_db_file_reader(self):
|
||||
ws = workspace.C.Workspace()
|
||||
session = LocalSession(ws)
|
||||
db_path = self._make_temp_path()
|
||||
|
||||
# Build a cache DB file.
|
||||
cached_reader = CachedReader(
|
||||
self._build_source_reader(ws, 100),
|
||||
db_path=db_path,
|
||||
db_type='LevelDB',
|
||||
)
|
||||
build_cache_step = cached_reader.build_cache_step()
|
||||
session.run(build_cache_step)
|
||||
|
||||
# Read data from cache DB file.
|
||||
workspace.ResetWorkspace()
|
||||
db_file_reader = DBFileReader(
|
||||
db_path=db_path,
|
||||
db_type='LevelDB',
|
||||
)
|
||||
data = self._read_all_data(ws, db_file_reader, session)
|
||||
self.assertEqual(sorted(data), list(range(100)))
|
||||
|
||||
self._delete_path(db_path)
|
||||
|
157
caffe2/python/db_file_reader.py
Normal file
157
caffe2/python/db_file_reader.py
Normal file
@ -0,0 +1,157 @@
|
||||
## @package db_file_reader
|
||||
# Module caffe2.python.db_file_reader
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core, scope, workspace, _import_c_extension as C
|
||||
from caffe2.python.dataio import Reader
|
||||
from caffe2.python.dataset import Dataset
|
||||
from caffe2.python.schema import from_column_list
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class DBFileReader(Reader):
|
||||
|
||||
default_name_suffix = 'db_file_reader'
|
||||
|
||||
"""Reader reads from a DB file.
|
||||
|
||||
Example usage:
|
||||
db_file_reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB')
|
||||
|
||||
Args:
|
||||
db_path: str.
|
||||
db_type: str. DB type of file. A db_type is registed by
|
||||
`REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
|
||||
name: str or None. Name of DBFileReader.
|
||||
Optional name to prepend to blobs that will store the data.
|
||||
Default to '<db_name>_<default_name_suffix>'.
|
||||
batch_size: int.
|
||||
How many examples are read for each time the read_net is run.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
db_path,
|
||||
db_type,
|
||||
name=None,
|
||||
batch_size=100,
|
||||
):
|
||||
assert db_path is not None, "db_path can't be None."
|
||||
assert db_type in C.registered_dbs(), \
|
||||
"db_type [{db_type}] is not available. \n" \
|
||||
"Choose one of these: {registered_dbs}.".format(
|
||||
db_type=db_type,
|
||||
registered_dbs=C.registered_dbs(),
|
||||
)
|
||||
|
||||
self.db_path = db_path
|
||||
self.db_type = db_type
|
||||
self.name = name or '{db_name}_{default_name_suffix}'.format(
|
||||
db_name=self._extract_db_name_from_db_path(),
|
||||
default_name_suffix=self.default_name_suffix,
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
|
||||
# Before self._init_reader_schema(...),
|
||||
# self.db_path and self.db_type are required to be set.
|
||||
super(DBFileReader, self).__init__(self._init_reader_schema())
|
||||
self.ds = Dataset(self._schema, self.name + '_dataset')
|
||||
self.ds_reader = None
|
||||
|
||||
def _init_name(self, name):
|
||||
return name or self._extract_db_name_from_db_path(
|
||||
) + '_db_file_reader'
|
||||
|
||||
def _init_reader_schema(self):
|
||||
"""Restore a reader schema from the DB file.
|
||||
|
||||
Here it is assumed that:
|
||||
1). Each field of the schema have corresponding blobs
|
||||
stored in the DB file.
|
||||
2). Each blob loaded from the DB file corresponds to
|
||||
a field of the schema.
|
||||
|
||||
Load a set of blobs from a DB file. From names of these blobs,
|
||||
restore the DB file schema using `from_column_list(...)`.
|
||||
|
||||
Returns:
|
||||
schema: schema.Struct. Used in Reader.__init__(...).
|
||||
"""
|
||||
assert os.path.exists(self.db_path), \
|
||||
'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
|
||||
with core.NameScope(self.name):
|
||||
# blob_prefix is for avoiding name conflict in workspace
|
||||
blob_prefix = scope.CurrentNameScope()
|
||||
workspace.RunOperatorOnce(
|
||||
core.CreateOperator(
|
||||
'Load',
|
||||
[],
|
||||
[],
|
||||
absolute_path=True,
|
||||
db=self.db_path,
|
||||
db_type=self.db_type,
|
||||
load_all=True,
|
||||
add_prefix=blob_prefix,
|
||||
)
|
||||
)
|
||||
col_names = [
|
||||
blob_name[len(blob_prefix):] for blob_name in workspace.Blobs()
|
||||
if blob_name.startswith(blob_prefix)
|
||||
]
|
||||
schema = from_column_list(col_names)
|
||||
return schema
|
||||
|
||||
def setup_ex(self, init_net, finish_net):
|
||||
"""From the Dataset, create a _DatasetReader and setup a init_net.
|
||||
|
||||
Make sure the _init_field_blobs_as_empty(...) is only called once.
|
||||
|
||||
Because the underlying NewRecord(...) creats blobs by calling
|
||||
NextScopedBlob(...), so that references to previously-initiated
|
||||
empty blobs will be lost, causing accessibility issue.
|
||||
"""
|
||||
if self.ds_reader:
|
||||
self.ds_reader.setup_ex(init_net, finish_net)
|
||||
else:
|
||||
self._init_field_blobs_as_empty(init_net)
|
||||
self._feed_field_blobs_from_db_file(init_net)
|
||||
self.ds_reader = self.ds.reader(
|
||||
init_net,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
|
||||
def read(self, read_net):
|
||||
assert self.ds_reader, 'setup_ex must be called first'
|
||||
return self.ds_reader.read(read_net)
|
||||
|
||||
def _init_field_blobs_as_empty(self, init_net):
|
||||
"""Initialize dataset field blobs by creating an empty record"""
|
||||
with core.NameScope(self.name):
|
||||
self.ds.init_empty(init_net)
|
||||
|
||||
def _feed_field_blobs_from_db_file(self, net):
|
||||
"""Load from the DB file at db_path and feed dataset field blobs"""
|
||||
assert os.path.exists(self.db_path), \
|
||||
'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
|
||||
net.Load(
|
||||
[],
|
||||
self.ds.get_blobs(),
|
||||
db=self.db_path,
|
||||
db_type=self.db_type,
|
||||
absolute_path=True,
|
||||
source_blob_names=self.ds.field_names(),
|
||||
)
|
||||
|
||||
def _extract_db_name_from_db_path(self):
|
||||
"""Extract DB name from DB path
|
||||
|
||||
E.g. given self.db_path=`/tmp/sample.db`,
|
||||
it returns `sample`.
|
||||
|
||||
Returns:
|
||||
db_name: str.
|
||||
"""
|
||||
return os.path.basename(self.db_path).rsplit('.', 1)[0]
|
@ -3,6 +3,7 @@
|
||||
import numpy as np
|
||||
import copy
|
||||
from caffe2.python import workspace
|
||||
from caffe2.python.core import InferOpBlobDevicesAsDict
|
||||
from future.utils import viewitems
|
||||
|
||||
|
||||
@ -31,17 +32,20 @@ class DeviceChecker(object):
|
||||
boolean: True if it passes, False if it does not pass.
|
||||
"""
|
||||
op = copy.deepcopy(op)
|
||||
input_device_options = input_device_options or {}
|
||||
# Entering the checker workspace
|
||||
old_ws_name = workspace.CurrentWorkspace()
|
||||
results = []
|
||||
workspace.SwitchWorkspace("_device_check_", True)
|
||||
for i, device_option in enumerate(self._device_options):
|
||||
op.device_option.CopyFrom(device_option)
|
||||
_input_device_options = input_device_options or \
|
||||
InferOpBlobDevicesAsDict(op)[0]
|
||||
print(_input_device_options)
|
||||
for i, arr in enumerate(inputs):
|
||||
workspace.FeedBlob(
|
||||
op.input[i], np.array(arr),
|
||||
input_device_options.get(op.input[i], device_option))
|
||||
op.device_option.CopyFrom(device_option)
|
||||
_input_device_options.get(op.input[i], device_option)
|
||||
)
|
||||
workspace.RunOperatorOnce(op)
|
||||
results.append(
|
||||
[workspace.FetchBlob(op.output[idx])
|
||||
|
@ -158,12 +158,12 @@ class GradientChecker:
|
||||
self,
|
||||
stepsize,
|
||||
threshold,
|
||||
device_option=caffe2_pb2.DeviceOption(),
|
||||
device_option=None,
|
||||
workspace_name="gradient_check"
|
||||
):
|
||||
self._stepsize = stepsize
|
||||
self._threshold = threshold
|
||||
self._device_option = device_option
|
||||
self._device_option = device_option or caffe2_pb2.DeviceOption()
|
||||
self._workspace_name = workspace_name
|
||||
|
||||
def GetLossAndGrad(
|
||||
@ -239,8 +239,6 @@ class GradientChecker:
|
||||
Outputs:
|
||||
boolean: True if it passes, False if it does not pass.
|
||||
"""
|
||||
if input_device_options is None:
|
||||
input_device_options = {}
|
||||
# Entering the checker workspace
|
||||
old_ws_name = workspace.CurrentWorkspace()
|
||||
if self._workspace_name != old_ws_name:
|
||||
@ -254,11 +252,13 @@ class GradientChecker:
|
||||
op, [s + '_grad' for s in op.output])
|
||||
|
||||
dims_to_check = inputs[input_to_check].size
|
||||
_input_device_options = input_device_options or \
|
||||
core.InferOpBlobDevicesAsDict(op)[0]
|
||||
# First, feed in the input.
|
||||
for i, arr in enumerate(inputs):
|
||||
workspace.FeedBlob(
|
||||
op.input[i], arr,
|
||||
input_device_options.get(
|
||||
_input_device_options.get(
|
||||
op.input[i], self._device_option))
|
||||
|
||||
# Get the loss and gradient for the original.
|
||||
|
@ -1553,6 +1553,14 @@ class TestOperators(hu.HypothesisTestCase):
|
||||
op = core.CreateOperator("Shape", ["data"], ["shape"])
|
||||
self.assertReferenceChecks(gc, op, [data], lambda x: (x.shape, ))
|
||||
|
||||
@given(data=hu.tensor(), **hu.gcs_cpu_only)
|
||||
def test_shape_with_axes(self, data, gc, dc):
|
||||
def shape_ref(x, y):
|
||||
return ([x.shape[i] for i in y],)
|
||||
axes = np.random.randint(len(data.shape), size=10).tolist()
|
||||
op = core.CreateOperator("Shape", ["data"], ["shape"], axes=axes)
|
||||
self.assertReferenceChecks(gc, op, [data, axes], shape_ref)
|
||||
|
||||
@given(data=hu.tensor(), **hu.gcs_cpu_only)
|
||||
def test_has_elements(self, data, gc, dc):
|
||||
op = core.CreateOperator("HasElements", ["data"], ["has_elements"])
|
||||
|
@ -292,8 +292,6 @@ def runOpBenchmark(
|
||||
input_device_options=None,
|
||||
iterations=10,
|
||||
):
|
||||
if input_device_options is None:
|
||||
input_device_options = {}
|
||||
op = copy.deepcopy(op)
|
||||
op.device_option.CopyFrom(device_option)
|
||||
net = caffe2_pb2.NetDef()
|
||||
@ -301,11 +299,13 @@ def runOpBenchmark(
|
||||
net.name = op.name if op.name else "test"
|
||||
|
||||
with temp_workspace():
|
||||
_input_device_options = input_device_options or \
|
||||
core.InferOpBlobDevicesAsDict(op)[0]
|
||||
for (n, b) in zip(op.input, inputs):
|
||||
workspace.FeedBlob(
|
||||
n,
|
||||
b,
|
||||
device_option=input_device_options.get(n, device_option)
|
||||
device_option=_input_device_options.get(n, device_option)
|
||||
)
|
||||
workspace.CreateNet(net)
|
||||
ret = workspace.BenchmarkNet(net.name, 1, iterations, True)
|
||||
@ -519,9 +519,6 @@ class HypothesisTestCase(test_util.TestCase):
|
||||
|
||||
self.assertReferenceChecks(gc, op, [X], softsign)
|
||||
"""
|
||||
if input_device_options is None:
|
||||
input_device_options = {}
|
||||
|
||||
op = copy.deepcopy(op)
|
||||
op.device_option.CopyFrom(device_option)
|
||||
|
||||
@ -530,11 +527,13 @@ class HypothesisTestCase(test_util.TestCase):
|
||||
raise ValueError(
|
||||
'must supply an input for each input on the op: %s vs %s' %
|
||||
(op.input, inputs))
|
||||
_input_device_options = input_device_options or \
|
||||
core.InferOpBlobDevicesAsDict(op)[0]
|
||||
for (n, b) in zip(op.input, inputs):
|
||||
workspace.FeedBlob(
|
||||
n,
|
||||
b,
|
||||
device_option=input_device_options.get(n, device_option)
|
||||
device_option=_input_device_options.get(n, device_option)
|
||||
)
|
||||
net = core.Net("opnet")
|
||||
net.Proto().op.extend([op])
|
||||
@ -600,8 +599,6 @@ class HypothesisTestCase(test_util.TestCase):
|
||||
as_kwargs=True,
|
||||
init_net=None,
|
||||
):
|
||||
if input_device_options is None:
|
||||
input_device_options = {}
|
||||
if as_kwargs:
|
||||
assert len(set(list(op.input) + list(op.output))) == \
|
||||
len(op.input) + len(op.output), \
|
||||
@ -610,11 +607,13 @@ class HypothesisTestCase(test_util.TestCase):
|
||||
op.device_option.CopyFrom(device_option)
|
||||
|
||||
with temp_workspace():
|
||||
_input_device_options = input_device_options or \
|
||||
core.InferOpBlobDevicesAsDict(op)[0]
|
||||
for (n, b) in zip(op.input, inputs):
|
||||
workspace.FeedBlob(
|
||||
n,
|
||||
b,
|
||||
device_option=input_device_options.get(n, device_option)
|
||||
device_option=_input_device_options.get(n, device_option)
|
||||
)
|
||||
if init_net:
|
||||
workspace.RunNetOnce(init_net)
|
||||
@ -635,18 +634,17 @@ class HypothesisTestCase(test_util.TestCase):
|
||||
exception=(Exception,),
|
||||
regexp=None,
|
||||
):
|
||||
if input_device_options is None:
|
||||
input_device_options = {}
|
||||
|
||||
op = copy.deepcopy(op)
|
||||
op.device_option.CopyFrom(device_option)
|
||||
|
||||
with temp_workspace():
|
||||
_input_device_options = input_device_options or \
|
||||
core.InferOpBlobDevicesAsDict(op)[0]
|
||||
for (n, b) in zip(op.input, inputs):
|
||||
workspace.FeedBlob(
|
||||
n,
|
||||
b,
|
||||
device_option=input_device_options.get(n, device_option)
|
||||
device_option=_input_device_options.get(n, device_option)
|
||||
)
|
||||
if regexp is None:
|
||||
self.assertRaises(exception, workspace.RunOperatorOnce, op)
|
||||
|
@ -52,6 +52,8 @@ class SparseFeatureHash(ModelLayer):
|
||||
assert False, "Input type must be one of (IdList, IdScoreList)"
|
||||
|
||||
assert self.modulo >= 1, 'Unexpected modulo: {}'.format(self.modulo)
|
||||
if input_record.lengths.metadata:
|
||||
self.output_schema.lengths.set_metadata(input_record.lengths.metadata)
|
||||
|
||||
# operators in this layer do not have CUDA implementation yet.
|
||||
# In addition, since the sparse feature keys that we are hashing are
|
||||
|
@ -30,7 +30,6 @@ backend_test.exclude(r'(test_hardsigmoid' # Does not support Hardsigmoid.
|
||||
'|test_reduce_mean_cuda.*' # Does not support ReduceMean CUDA.
|
||||
'|test_reduce_prod.*' # Does not support ReduceProd.
|
||||
'|test_reduce_sum.*' # Does not support ReduceSum and ReduceSumSquare
|
||||
'|test_reduce_log_sum.*' # Does not support ReduceLogSum
|
||||
'|test_tile.*' # Tile's Caffe2 implementation needs some tweak
|
||||
'|test_lstm.*' # Seems LSTM case has some problem
|
||||
'|test_simple_rnn.*' # Seems simple RNN case has some problem
|
||||
|
@ -7,6 +7,7 @@ import numpy as np
|
||||
import hypothesis.strategies as st
|
||||
import unittest
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
from caffe2.proto import caffe2_pb2
|
||||
from caffe2.python import core
|
||||
from hypothesis import given
|
||||
|
||||
@ -126,6 +127,54 @@ class TestConcatSplitOps(hu.HypothesisTestCase):
|
||||
self.assertDeviceChecks(dc, op, input_tensors, outputs_with_grad)
|
||||
self.assertGradientChecks(gc, op, input_tensors, 0, outputs_with_grad)
|
||||
|
||||
@given(
|
||||
inputs=hu.lengths_tensor(
|
||||
dtype=np.float32,
|
||||
min_value=1,
|
||||
max_value=5,
|
||||
allow_empty=True,
|
||||
),
|
||||
**hu.gcs
|
||||
)
|
||||
def test_split_by_lengths(self, inputs, gc, dc):
|
||||
data, lengths = inputs
|
||||
len_len = len(lengths)
|
||||
|
||||
def _find_factor_simple(x):
|
||||
for i in [2, 3, 5]:
|
||||
if x % i == 0:
|
||||
return i
|
||||
return x
|
||||
|
||||
num_output = _find_factor_simple(len_len)
|
||||
axis = 0
|
||||
op = core.CreateOperator(
|
||||
"SplitByLengths",
|
||||
["data", "lengths"],
|
||||
['X_{}'.format(i) for i in range(num_output)],
|
||||
axis=axis,
|
||||
)
|
||||
|
||||
def split_by_lengths_ref(data, lengths, num_output=num_output, axis=0):
|
||||
idxs = np.cumsum([0] + list(lengths)).astype(np.int32)
|
||||
return [
|
||||
np.array(
|
||||
data.take(
|
||||
np.arange(
|
||||
idxs[i * len_len // num_output],
|
||||
idxs[(i + 1) * len_len // num_output]
|
||||
),
|
||||
axis=axis
|
||||
)
|
||||
) for i in range(num_output)
|
||||
]
|
||||
outputs_with_grad = range(num_output)
|
||||
input_tensors = [data, lengths]
|
||||
self.assertReferenceChecks(gc, op, input_tensors, split_by_lengths_ref)
|
||||
self.assertDeviceChecks(dc, op, input_tensors, outputs_with_grad)
|
||||
self.assertGradientChecks(gc, op, input_tensors, 0, outputs_with_grad)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
51
caffe2/python/operator_test/ensure_cpu_output_op_test.py
Normal file
51
caffe2/python/operator_test/ensure_cpu_output_op_test.py
Normal file
@ -0,0 +1,51 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from hypothesis import given
|
||||
import numpy as np
|
||||
import hypothesis.strategies as st
|
||||
|
||||
from caffe2.python import core, workspace
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
|
||||
|
||||
@st.composite
|
||||
def _dev_options(draw):
|
||||
op_dev = draw(st.sampled_from(hu.device_options))
|
||||
if op_dev == hu.cpu_do:
|
||||
# the CPU op can only handle CPU tensor
|
||||
input_blob_dev = hu.cpu_do
|
||||
else:
|
||||
input_blob_dev = draw(st.sampled_from(hu.device_options))
|
||||
|
||||
return op_dev, input_blob_dev
|
||||
|
||||
|
||||
class TestEnsureCPUOutputOp(hu.HypothesisTestCase):
|
||||
|
||||
@given(
|
||||
input=hu.tensor(dtype=np.float32),
|
||||
dev_options=_dev_options()
|
||||
)
|
||||
def test_ensure_cpu_output(self, input, dev_options):
|
||||
op_dev, input_blob_dev = dev_options
|
||||
net = core.Net('test_net')
|
||||
data = net.GivenTensorFill(
|
||||
[],
|
||||
["data"],
|
||||
values=input,
|
||||
shape=input.shape,
|
||||
device_option=input_blob_dev
|
||||
)
|
||||
|
||||
data_cpu = net.EnsureCPUOutput(
|
||||
[data],
|
||||
["data_cpu"],
|
||||
device_option=op_dev
|
||||
)
|
||||
workspace.RunNetOnce(net)
|
||||
|
||||
data_cpu_value = workspace.FetchBlob(data_cpu)
|
||||
np.testing.assert_allclose(input, data_cpu_value)
|
@ -40,7 +40,7 @@ class TestGroupNormOp(hu.HypothesisTestCase):
|
||||
Y = gamma * (X - mu) / std + beta
|
||||
return [Y.reshape(dims), mu.reshape(N, G), (1.0 / std).reshape(N, G)]
|
||||
|
||||
@given(N=st.integers(1, 5), G=st.integers(1, 3), D=st.integers(1, 3),
|
||||
@given(N=st.integers(1, 5), G=st.integers(1, 5), D=st.integers(2, 2),
|
||||
H=st.integers(2, 5), W=st.integers(2, 5),
|
||||
epsilon=st.floats(min_value=1e-5, max_value=1e-4),
|
||||
order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
|
||||
@ -60,8 +60,8 @@ class TestGroupNormOp(hu.HypothesisTestCase):
|
||||
X = np.random.randn(N, C, H, W).astype(np.float32) + 1.0
|
||||
else:
|
||||
X = np.random.randn(N, H, W, C).astype(np.float32) + 1.0
|
||||
gamma = np.random.rand(C).astype(np.float32) - 0.5
|
||||
beta = np.random.rand(C).astype(np.float32) - 0.5
|
||||
gamma = np.random.randn(C).astype(np.float32)
|
||||
beta = np.random.randn(C).astype(np.float32)
|
||||
inputs = [X, gamma, beta]
|
||||
|
||||
def ref_op(X, gamma, beta):
|
||||
@ -74,14 +74,12 @@ class TestGroupNormOp(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=inputs,
|
||||
reference=ref_op,
|
||||
threshold=5e-4,
|
||||
threshold=5e-3,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, inputs, [0, 1, 2])
|
||||
for i in range(len(inputs)):
|
||||
self.assertGradientChecks(gc, op, inputs, i, [0])
|
||||
|
||||
@given(N=st.integers(1, 5), G=st.integers(1, 3), D=st.integers(1, 3),
|
||||
T=st.integers(1, 3), H=st.integers(2, 5), W=st.integers(2, 5),
|
||||
@given(N=st.integers(1, 5), G=st.integers(1, 3), D=st.integers(2, 3),
|
||||
T=st.integers(2, 4), H=st.integers(2, 4), W=st.integers(2, 4),
|
||||
epsilon=st.floats(min_value=1e-5, max_value=1e-4),
|
||||
order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
|
||||
def test_group_norm_3d(
|
||||
@ -100,8 +98,8 @@ class TestGroupNormOp(hu.HypothesisTestCase):
|
||||
X = np.random.randn(N, C, T, H, W).astype(np.float32) + 1.0
|
||||
else:
|
||||
X = np.random.randn(N, T, H, W, C).astype(np.float32) + 1.0
|
||||
gamma = np.random.rand(C).astype(np.float32) - 0.5
|
||||
beta = np.random.rand(C).astype(np.float32) - 0.5
|
||||
gamma = np.random.randn(C).astype(np.float32)
|
||||
beta = np.random.randn(C).astype(np.float32)
|
||||
inputs = [X, gamma, beta]
|
||||
|
||||
def ref_op(X, gamma, beta):
|
||||
@ -114,8 +112,34 @@ class TestGroupNormOp(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=inputs,
|
||||
reference=ref_op,
|
||||
threshold=5e-4,
|
||||
threshold=5e-3,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, inputs, [0, 1, 2])
|
||||
|
||||
@given(N=st.integers(1, 5), G=st.integers(1, 5), D=st.integers(2, 2),
|
||||
H=st.integers(2, 5), W=st.integers(2, 5),
|
||||
epsilon=st.floats(min_value=1e-5, max_value=1e-4),
|
||||
order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
|
||||
def test_group_norm_grad(
|
||||
self, N, G, D, H, W, epsilon, order, gc, dc):
|
||||
op = core.CreateOperator(
|
||||
"GroupNorm",
|
||||
["X", "gamma", "beta"],
|
||||
["Y", "mean", "inv_std"],
|
||||
group=G,
|
||||
epsilon=epsilon,
|
||||
order=order,
|
||||
)
|
||||
|
||||
C = G * D
|
||||
X = np.arange(N * C * H * W).astype(np.float32)
|
||||
np.random.shuffle(X)
|
||||
if order == "NCHW":
|
||||
X = X.reshape((N, C, H, W))
|
||||
else:
|
||||
X = X.reshape((N, H, W, C))
|
||||
gamma = np.random.randn(C).astype(np.float32)
|
||||
beta = np.random.randn(C).astype(np.float32)
|
||||
inputs = [X, gamma, beta]
|
||||
for i in range(len(inputs)):
|
||||
self.assertGradientChecks(gc, op, inputs, i, [0])
|
||||
|
@ -7,9 +7,12 @@ from caffe2.python import core
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
|
||||
from hypothesis import given
|
||||
import hypothesis.strategies as st
|
||||
|
||||
import numpy as np
|
||||
import copy
|
||||
from functools import partial
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestLearningRate(hu.HypothesisTestCase):
|
||||
@ -78,6 +81,110 @@ class TestLearningRate(hu.HypothesisTestCase):
|
||||
)
|
||||
self.assertReferenceChecks(gc, op, [iter], ref)
|
||||
|
||||
@given(gc=hu.gcs['gc'],
|
||||
min_num_iter=st.integers(min_value=10, max_value=20),
|
||||
max_num_iter=st.integers(min_value=50, max_value=100))
|
||||
def test_composite_learning_rate_op(self, gc, min_num_iter, max_num_iter):
|
||||
np.random.seed(65535)
|
||||
# Generate the iteration numbers for sub policy
|
||||
# The four sub policies are as follows:
|
||||
# 1. exp; 2. step; 3. fix; 4. exp
|
||||
num_lr_policy = 4
|
||||
iter_nums = np.random.randint(
|
||||
low=min_num_iter, high=max_num_iter, size=num_lr_policy)
|
||||
accu_iter_num = copy.deepcopy(iter_nums)
|
||||
for i in range(1, num_lr_policy):
|
||||
accu_iter_num[i] += accu_iter_num[i - 1]
|
||||
total_iter_nums = accu_iter_num[-1]
|
||||
|
||||
policy_lr_scale = np.random.uniform(low=2.0, high=2.0, size=num_lr_policy)
|
||||
|
||||
# args for StepLRPolicy
|
||||
step_size = np.random.randint(low=2, high=min_num_iter // 2)
|
||||
step_gamma = np.random.random()
|
||||
# args for ExpLRPolicy
|
||||
exp_gamma = np.random.random()
|
||||
# common args
|
||||
base_lr = 0.1
|
||||
|
||||
# StepLRPolicy
|
||||
def step_lr(iter, lr_scale):
|
||||
return math.pow(step_gamma, iter // step_size) * lr_scale
|
||||
|
||||
# ExpLRPolicy
|
||||
def exp_lr(iter, lr_scale):
|
||||
return math.pow(exp_gamma, iter) * lr_scale
|
||||
|
||||
# FixedLRPolicy
|
||||
def fixed_lr(iter, lr_scale):
|
||||
return lr_scale
|
||||
|
||||
# test one sub policy case
|
||||
def one_policy_check_ref(iter, lr_scale):
|
||||
iter = int(iter)
|
||||
exp_lr_val = exp_lr(iter, lr_scale=lr_scale)
|
||||
return (np.array(base_lr * exp_lr_val), )
|
||||
|
||||
op = core.CreateOperator(
|
||||
'LearningRate',
|
||||
'data',
|
||||
'out',
|
||||
policy='composite',
|
||||
sub_policy_num_iters=iter_nums[:1],
|
||||
sub_policy_0_lr_scale=policy_lr_scale[0],
|
||||
sub_policy_0_policy='exp',
|
||||
sub_policy_0_gamma=exp_gamma,
|
||||
base_lr=base_lr,
|
||||
)
|
||||
for iter_idx in range(1, total_iter_nums + 1):
|
||||
self.assertReferenceChecks(
|
||||
gc, op, [np.asarray([iter_idx])],
|
||||
partial(one_policy_check_ref, lr_scale=policy_lr_scale[0]))
|
||||
|
||||
# all the case with all four sub policies
|
||||
def all_sub_policy_check_ref(iter, lr_scale):
|
||||
assert iter <= accu_iter_num[3]
|
||||
if iter <= accu_iter_num[0]:
|
||||
lr = exp_lr(iter, lr_scale=lr_scale)
|
||||
elif iter <= accu_iter_num[1]:
|
||||
lr = step_lr(iter, lr_scale=lr_scale)
|
||||
elif iter <= accu_iter_num[2]:
|
||||
lr = fixed_lr(iter, lr_scale=lr_scale)
|
||||
else:
|
||||
lr = exp_lr(iter, lr_scale=lr_scale)
|
||||
return (np.array(base_lr * lr), )
|
||||
|
||||
op = core.CreateOperator(
|
||||
'LearningRate',
|
||||
'data',
|
||||
'out',
|
||||
policy='composite',
|
||||
sub_policy_num_iters=iter_nums,
|
||||
sub_policy_0_policy='exp',
|
||||
sub_policy_0_lr_scale=policy_lr_scale[0],
|
||||
sub_policy_0_gamma=exp_gamma,
|
||||
sub_policy_1_policy='step',
|
||||
sub_policy_1_lr_scale=policy_lr_scale[1],
|
||||
sub_policy_1_stepsize=step_size,
|
||||
sub_policy_1_gamma=step_gamma,
|
||||
sub_policy_2_policy='fixed',
|
||||
sub_policy_2_lr_scale=policy_lr_scale[2],
|
||||
sub_policy_3_policy='exp',
|
||||
sub_policy_3_gamma=exp_gamma,
|
||||
sub_policy_3_lr_scale=policy_lr_scale[3],
|
||||
base_lr=base_lr,
|
||||
)
|
||||
|
||||
iter_policy = 0
|
||||
for iter_idx in range(1, total_iter_nums + 1):
|
||||
if iter_idx > accu_iter_num[iter_policy]:
|
||||
iter_policy += 1
|
||||
self.assertReferenceChecks(
|
||||
gc, op, [np.asarray([iter_idx])],
|
||||
partial(all_sub_policy_check_ref,
|
||||
lr_scale=policy_lr_scale[iter_policy])
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
51
caffe2/python/operator_test/relu_n_op_test.py
Normal file
51
caffe2/python/operator_test/relu_n_op_test.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright (c) 2016-present, Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
##############################################################################
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core
|
||||
from hypothesis import given
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
import numpy as np
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestRelu(hu.HypothesisTestCase):
|
||||
|
||||
@given(X=hu.tensor(),
|
||||
**hu.gcs)
|
||||
def test_relu_n(self, X, gc, dc):
|
||||
X = 0.8 * np.sign(X)
|
||||
X = X - 0.5
|
||||
X[X == 0.0] = 0.01
|
||||
n = max(np.max(X), 1.0) / 2
|
||||
X[X == 0.2] = 0.01
|
||||
|
||||
def relu_n_ref(X):
|
||||
Y = np.minimum(np.maximum(X, 0), n)
|
||||
return [Y]
|
||||
|
||||
op = core.CreateOperator("ReluN", ["X"], ["Y"], n=n)
|
||||
self.assertReferenceChecks(gc, op, [X], relu_n_ref)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=0.001, threshold=0.001)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -486,6 +486,13 @@ class TestShapeInference(test_util.TestCase):
|
||||
workspace.FeedBlob('x', np.random.rand(1, 2, 3, 4).astype(np.float32))
|
||||
self.InferTensorRunAndCompare(model)
|
||||
|
||||
def testInt8Conversion(self):
|
||||
model = model_helper.ModelHelper(name="int8_conversion_test")
|
||||
model.FloatToFused8BitRowwiseQuantized('x', 'x_8bit')
|
||||
model.Fused8BitRowwiseQuantizedToFloat('x_8bit', 'x_recovered')
|
||||
workspace.FeedBlob('x', np.random.rand(100, 150).astype(np.float32))
|
||||
self.InferTensorRunAndCompare(model)
|
||||
|
||||
def InferTensorRunAndCompare(self, model, expected_uninferred_blobs=None):
|
||||
'''
|
||||
Runs shape inference, and then the model to check
|
||||
|
@ -28,7 +28,8 @@ class Optimizer(object):
|
||||
self._instance_num = _optimizer_instance_count[self.__class__.__name__]
|
||||
_optimizer_instance_count[self.__class__.__name__] += 1
|
||||
self._lr_multiplier = None
|
||||
self._lr_multiplier_on_gpu = False
|
||||
self._local_lr_multiplier = None
|
||||
self._local_lr_multiplier_on_gpu = False
|
||||
|
||||
'''
|
||||
Adds optimization operators to the net for given parameter and its gradient
|
||||
@ -117,29 +118,59 @@ class Optimizer(object):
|
||||
lr = net.GetBlobRef(learning_rate_blob)
|
||||
|
||||
if self._lr_multiplier is not None:
|
||||
current_scope = scope.CurrentDeviceScope()
|
||||
if (current_scope is not None
|
||||
and current_scope.device_type == caffe2_pb2.CUDA
|
||||
and not self._lr_multiplier_on_gpu):
|
||||
lr_multiplier = net.CopyFromCPUInput(
|
||||
self._lr_multiplier,
|
||||
self.make_unique_blob_name('lr_multiplier')
|
||||
)
|
||||
else:
|
||||
lr_multiplier = self._lr_multiplier
|
||||
lr_multiplier = net.CopyFromCPUInput(
|
||||
self._lr_multiplier, self.make_unique_blob_name('lr_multiplier')
|
||||
)
|
||||
|
||||
scaled_lr = net.Mul(
|
||||
lr = net.Mul(
|
||||
[lr, lr_multiplier],
|
||||
self.make_unique_blob_name('scaled_lr'),
|
||||
broadcast=1,
|
||||
)
|
||||
lr = scaled_lr
|
||||
|
||||
if self._local_lr_multiplier is not None:
|
||||
current_scope = scope.CurrentDeviceScope()
|
||||
if (current_scope is not None
|
||||
and current_scope.device_type == caffe2_pb2.CUDA
|
||||
and not self._local_lr_multiplier_on_gpu):
|
||||
local_lr_multiplier = net.CopyFromCPUInput(
|
||||
self._local_lr_multiplier,
|
||||
self.make_unique_blob_name('local_lr_multiplier')
|
||||
)
|
||||
else:
|
||||
local_lr_multiplier = self._local_lr_multiplier
|
||||
|
||||
lr = net.Mul(
|
||||
[lr, local_lr_multiplier],
|
||||
self.make_unique_blob_name('local_scaled_lr'),
|
||||
broadcast=1,
|
||||
)
|
||||
|
||||
return lr, iteration
|
||||
|
||||
def add_lr_multiplier(self, lr_multiplier, is_gpu_blob=False):
|
||||
def add_lr_multiplier(self, lr_multiplier):
|
||||
"""
|
||||
Set the global learning rate multiplier. If a multiplier already
|
||||
existed, this will overwrite the existing multiplier. The multiplier is
|
||||
used for all future calls to _run(), unless it is overwritten.
|
||||
"""
|
||||
self._lr_multiplier = lr_multiplier
|
||||
self._lr_multiplier_on_gpu = is_gpu_blob
|
||||
|
||||
def _add_local_lr_multiplier(self, local_lr_multiplier, is_gpu_blob=False):
|
||||
"""
|
||||
Set the local learning rate multiplier. This local multiplier is
|
||||
multiplied with the global learning rate multiplier if it exists. As
|
||||
with the global learning rate multiplier, this multiplier will be
|
||||
used for all future calls to _run(), so please call
|
||||
_clear_local_lr_multiplier() at the beginning of the optimizer's _run()
|
||||
before optionally calling this function.
|
||||
"""
|
||||
self._local_lr_multiplier = local_lr_multiplier
|
||||
self._local_lr_multiplier_on_gpu = is_gpu_blob
|
||||
|
||||
def _clear_local_lr_multiplier(self):
|
||||
self._local_lr_multiplier = None
|
||||
self._local_lr_multiplier_on_gpu = False
|
||||
|
||||
@staticmethod
|
||||
def dedup(net, sparse_dedup_aggregator, grad):
|
||||
@ -206,6 +237,8 @@ class SgdOptimizer(Optimizer):
|
||||
"Expect positive base learning rate, got {}".format(
|
||||
self.base_learning_rate))
|
||||
|
||||
self._clear_local_lr_multiplier()
|
||||
|
||||
# TODO(zqq): support LARS for sparse parameters
|
||||
if self.lars is not None and not isinstance(grad, core.GradientSlice):
|
||||
assert self.lars >= 0, (
|
||||
@ -215,7 +248,7 @@ class SgdOptimizer(Optimizer):
|
||||
self.make_unique_blob_name(str(param) + "_lars"),
|
||||
offset=self.lars)
|
||||
current_scope = scope.CurrentDeviceScope()
|
||||
self.add_lr_multiplier(
|
||||
self._add_local_lr_multiplier(
|
||||
lr_lars_multiplier,
|
||||
is_gpu_blob=(current_scope is not None
|
||||
and current_scope.device_type == caffe2_pb2.CUDA),
|
||||
@ -492,6 +525,8 @@ class AdagradOptimizer(Optimizer):
|
||||
if self.alpha <= 0:
|
||||
return
|
||||
|
||||
self._clear_local_lr_multiplier()
|
||||
|
||||
if self.lars is not None and not isinstance(grad, core.GradientSlice):
|
||||
assert self.lars >= 0, (
|
||||
'Lars offset must be nonnegative, got {}'.format(self.lars))
|
||||
@ -500,7 +535,7 @@ class AdagradOptimizer(Optimizer):
|
||||
self.make_unique_blob_name(str(param) + "_lars"),
|
||||
offset=self.lars)
|
||||
current_scope = scope.CurrentDeviceScope()
|
||||
self.add_lr_multiplier(
|
||||
self._add_local_lr_multiplier(
|
||||
lr_lars_multiplier,
|
||||
is_gpu_blob=(current_scope is not None
|
||||
and current_scope.device_type == caffe2_pb2.CUDA),
|
||||
|
@ -1,6 +1,9 @@
|
||||
#ifndef CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_
|
||||
#define CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_
|
||||
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
@ -165,6 +168,40 @@ class HillLearningRate : public LearningRateFunctor<T> {
|
||||
T end_multiplier_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class CompositeLearningRateItem {
|
||||
public:
|
||||
CompositeLearningRateItem(int64_t num_iter, LearningRateFunctor<T>* policy)
|
||||
: num_iter_(num_iter), policy_(policy) {}
|
||||
int64_t num_iter_;
|
||||
LearningRateFunctor<T>* policy_;
|
||||
};
|
||||
|
||||
// composite: the learning policy changes according to current iteration #
|
||||
template <typename T>
|
||||
class CompositeLearningRate : public LearningRateFunctor<T> {
|
||||
public:
|
||||
CompositeLearningRate(
|
||||
const std::list<CompositeLearningRateItem<T>>& sub_policies) {
|
||||
DCHECK_GT(sub_policies.size(), 0);
|
||||
int64_t num_iter_start = 1;
|
||||
for (auto it = sub_policies.begin(); it != sub_policies.end(); ++it) {
|
||||
DCHECK_GT(it->num_iter_, 0);
|
||||
sub_policies_[num_iter_start].reset(it->policy_);
|
||||
num_iter_start += it->num_iter_;
|
||||
}
|
||||
}
|
||||
T operator()(const int64_t iter) const override {
|
||||
auto sub_policy = sub_policies_.upper_bound(iter);
|
||||
DCHECK(sub_policy != sub_policies_.begin());
|
||||
--sub_policy;
|
||||
return (*sub_policy->second)(iter);
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<int64_t, std::unique_ptr<LearningRateFunctor<T>>> sub_policies_;
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_
|
||||
|
@ -13,18 +13,23 @@ more exponential. Learning rate is controlled by the following arguments:
|
||||
|
||||
|
||||
Required:
|
||||
`iterations`
|
||||
`base_lr`: base learning rate
|
||||
`policy`: this controls how the learning rate is applied, options are:
|
||||
`fixed`
|
||||
`step`: uses `stepsize`, `gamma`
|
||||
`exp`: uses `gamma`
|
||||
`inv`: uses `gamma`, `power`
|
||||
`linearWarmup`: uses `start_multiplier`, `num_iter`
|
||||
`constantWarmup`: uses `multiplier`, `num_iter`
|
||||
`alter`: uses `active_first`, `active_period`, `inactive_period`
|
||||
`hill`: uses those in both `linearWarmup` and `inv`, plus `end_multiplier`
|
||||
|
||||
`iterations`
|
||||
`base_lr`: base learning rate
|
||||
`policy`: this controls how the learning rate is applied, options are:
|
||||
`fixed`
|
||||
`step`: uses `stepsize`, `gamma`
|
||||
`exp`: uses `gamma`
|
||||
`inv`: uses `gamma`, `power`
|
||||
`linearWarmup`: uses `start_multiplier`, `num_iter`
|
||||
`constantWarmup`: uses `multiplier`, `num_iter`
|
||||
`alter`: uses `active_first`, `active_period`, `inactive_period`
|
||||
`hill`: uses those in both `linearWarmup` and `inv`, plus `end_multiplier`
|
||||
`composite`: uses `sub_policy_num_iters` and additional args with format
|
||||
sub_policy_{sub_policy_index}_{sub_policy_arg}, for example:
|
||||
sub_policy_0_policy: "exp", sub_policy_0_gamma: 0.99,
|
||||
sub_policy_0_lr_scale: 1.2
|
||||
sub_policy_0_policy: "fixed", sub_policy_0_lr_scale: 1.0
|
||||
sub_policy_num_iters: [1000, 1000]
|
||||
|
||||
Optional:
|
||||
`stepsize`: defaults to 0
|
||||
@ -67,6 +72,9 @@ Example usage:
|
||||
.Arg(
|
||||
"multiplier",
|
||||
"(float, default 0.5) constant multiplier for learning rate")
|
||||
.Arg(
|
||||
"sub_policy_num_iters",
|
||||
"(int array, default empty) number of iterations for each sub learning rate policy in composite policy")
|
||||
.Input(0, "input", "description needed")
|
||||
.Output(0, "output", "description needed")
|
||||
.DeviceInferenceFunction([](const OperatorDef& def) {
|
||||
|
@ -21,85 +21,14 @@ class LearningRateOp final : public Operator<Context> {
|
||||
CAFFE_ENFORCE_NE(base_lr_, FLT_MAX, "Base learning rate must be set.");
|
||||
const string policy = OperatorBase::GetSingleArgument<string>("policy", "");
|
||||
CAFFE_ENFORCE(policy.size(), "Must specify a learning rate policy.");
|
||||
if (policy == "fixed") {
|
||||
functor_.reset(new FixedLearningRate<T>());
|
||||
} else if (policy == "alter") {
|
||||
bool active_first =
|
||||
OperatorBase::template GetSingleArgument<bool>("active_first", true);
|
||||
int64_t active_period = OperatorBase::template GetSingleArgument<int64_t>(
|
||||
"active_period", -1);
|
||||
int64_t inactive_period =
|
||||
OperatorBase::template GetSingleArgument<int64_t>(
|
||||
"inactive_period", -1);
|
||||
DCHECK_GE(active_period, 0);
|
||||
DCHECK_GE(inactive_period, 0);
|
||||
functor_.reset(new AlternateLearningRate<T>(
|
||||
active_period, inactive_period, active_first));
|
||||
} else if (policy == "hill") {
|
||||
int64_t num_iter =
|
||||
OperatorBase::template GetSingleArgument<int>("num_iter", 0);
|
||||
DCHECK_GT(num_iter, 0);
|
||||
T start_multiplier = OperatorBase::template GetSingleArgument<float>(
|
||||
"start_multiplier", 0.);
|
||||
DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1]
|
||||
DCHECK_LE(start_multiplier, 1);
|
||||
T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
|
||||
DCHECK_GT(gamma, 0);
|
||||
T power = OperatorBase::template GetSingleArgument<float>("power", 0);
|
||||
DCHECK_GT(power, 0);
|
||||
T end_multiplier =
|
||||
OperatorBase::template GetSingleArgument<float>("end_multiplier", 0);
|
||||
DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1]
|
||||
DCHECK_LE(end_multiplier, 1);
|
||||
functor_.reset(new HillLearningRate<T>(
|
||||
num_iter, start_multiplier, gamma, power, end_multiplier));
|
||||
} else if (policy == "step") {
|
||||
int stepsize =
|
||||
OperatorBase::template GetSingleArgument<int>("stepsize", 0);
|
||||
T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
|
||||
DCHECK_GT(stepsize, 0);
|
||||
DCHECK_GT(gamma, 0);
|
||||
functor_.reset(new StepLearningRate<T>(stepsize, gamma));
|
||||
} else if (policy == "exp") {
|
||||
T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
|
||||
DCHECK_GT(gamma, 0);
|
||||
functor_.reset(new ExpLearningRate<T>(gamma));
|
||||
} else if (policy == "inv") {
|
||||
T gamma = OperatorBase::template GetSingleArgument<float>("gamma", 0);
|
||||
T power = OperatorBase::template GetSingleArgument<float>("power", 0);
|
||||
DCHECK_GT(gamma, 0);
|
||||
DCHECK_GT(power, 0);
|
||||
functor_.reset(new InvLearningRate<T>(gamma, power));
|
||||
} else if (policy == "poly") {
|
||||
int max_iter = OperatorBase::template GetSingleArgument<int>("max_iter", -1);
|
||||
T power = OperatorBase::template GetSingleArgument<float>("power", 0);
|
||||
DCHECK_GT(power, 0);
|
||||
functor_.reset(new PolyLearningRate<T>(power, max_iter));
|
||||
} else if (policy == "linearWarmup") {
|
||||
T start_multiplier = OperatorBase::template GetSingleArgument<float>(
|
||||
"start_multiplier", 0.);
|
||||
int num_iter =
|
||||
OperatorBase::template GetSingleArgument<int>("num_iter", 0);
|
||||
DCHECK_GT(start_multiplier, 0);
|
||||
functor_.reset(
|
||||
new LinearWarmupLearningRate<T>(start_multiplier, num_iter));
|
||||
} else if (policy == "constantWarmup") {
|
||||
T multiplier =
|
||||
OperatorBase::template GetSingleArgument<float>("multiplier", 0.5);
|
||||
int num_iter =
|
||||
OperatorBase::template GetSingleArgument<int>("num_iter", 0);
|
||||
DCHECK_GT(multiplier, 0);
|
||||
functor_.reset(new ConstantWarmupLearningRate<T>(multiplier, num_iter));
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown learning rate policy: " << policy;
|
||||
}
|
||||
functor_.reset(createLearningRateFunctor(policy));
|
||||
}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
bool RunOnDevice() override {
|
||||
int64_t iter =
|
||||
OperatorBase::Input<TensorCPU>(0).template data<int64_t>()[0];
|
||||
T learning_rate = base_lr_ * (*functor_)(iter);
|
||||
T learning_rate = cur_base_lr_ * (*functor_)(iter);
|
||||
// Write to output.
|
||||
auto* output = Output(0);
|
||||
output->Resize(vector<TIndex>());
|
||||
@ -109,11 +38,131 @@ class LearningRateOp final : public Operator<Context> {
|
||||
}
|
||||
|
||||
private:
|
||||
unique_ptr<LearningRateFunctor<T> > functor_;
|
||||
unique_ptr<LearningRateFunctor<T>> functor_;
|
||||
T base_lr_;
|
||||
T base_lr_scale_;
|
||||
T cur_base_lr_;
|
||||
|
||||
LearningRateFunctor<T>* createLearningRateFunctor(
|
||||
const string& policy,
|
||||
const string& arg_prefix = "") {
|
||||
if (policy != "composite") {
|
||||
base_lr_scale_ =
|
||||
OperatorBase::GetSingleArgument<float>(arg_prefix + "lr_scale", 1.0);
|
||||
cur_base_lr_ = base_lr_scale_ * base_lr_;
|
||||
}
|
||||
if (policy == "fixed") {
|
||||
return new FixedLearningRate<T>();
|
||||
} else if (policy == "alter") {
|
||||
bool active_first = OperatorBase::template GetSingleArgument<bool>(
|
||||
arg_prefix + "active_first", true);
|
||||
int64_t active_period = OperatorBase::template GetSingleArgument<int64_t>(
|
||||
arg_prefix + "active_period", -1);
|
||||
int64_t inactive_period =
|
||||
OperatorBase::template GetSingleArgument<int64_t>(
|
||||
arg_prefix + "inactive_period", -1);
|
||||
DCHECK_GE(active_period, 0);
|
||||
DCHECK_GE(inactive_period, 0);
|
||||
return new AlternateLearningRate<T>(
|
||||
active_period, inactive_period, active_first);
|
||||
} else if (policy == "hill") {
|
||||
int64_t num_iter = OperatorBase::template GetSingleArgument<int>(
|
||||
arg_prefix + "num_iter", 0);
|
||||
DCHECK_GT(num_iter, 0);
|
||||
T start_multiplier = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "start_multiplier", 0.);
|
||||
DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1]
|
||||
DCHECK_LE(start_multiplier, 1);
|
||||
T gamma = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "gamma", 0);
|
||||
DCHECK_GT(gamma, 0);
|
||||
T power = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "power", 0);
|
||||
DCHECK_GT(power, 0);
|
||||
T end_multiplier = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "end_multiplier", 0);
|
||||
DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1]
|
||||
DCHECK_LE(end_multiplier, 1);
|
||||
return new HillLearningRate<T>(
|
||||
num_iter, start_multiplier, gamma, power, end_multiplier);
|
||||
} else if (policy == "step") {
|
||||
int stepsize = OperatorBase::template GetSingleArgument<int>(
|
||||
arg_prefix + "stepsize", 0);
|
||||
T gamma = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "gamma", 0);
|
||||
DCHECK_GT(stepsize, 0);
|
||||
DCHECK_GT(gamma, 0);
|
||||
return new StepLearningRate<T>(stepsize, gamma);
|
||||
} else if (policy == "exp") {
|
||||
T gamma = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "gamma", 0);
|
||||
DCHECK_GT(gamma, 0);
|
||||
return new ExpLearningRate<T>(gamma);
|
||||
} else if (policy == "inv") {
|
||||
T gamma = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "gamma", 0);
|
||||
T power = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "power", 0);
|
||||
DCHECK_GT(gamma, 0);
|
||||
DCHECK_GT(power, 0);
|
||||
return new InvLearningRate<T>(gamma, power);
|
||||
} else if (policy == "poly") {
|
||||
int max_iter = OperatorBase::template GetSingleArgument<int>(
|
||||
arg_prefix + "max_iter", -1);
|
||||
T power = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "power", 0);
|
||||
DCHECK_GT(power, 0);
|
||||
return new PolyLearningRate<T>(power, max_iter);
|
||||
} else if (policy == "linearWarmup") {
|
||||
T start_multiplier = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "start_multiplier", 0.);
|
||||
int num_iter = OperatorBase::template GetSingleArgument<int>(
|
||||
arg_prefix + "num_iter", 0);
|
||||
DCHECK_GT(start_multiplier, 0);
|
||||
return new LinearWarmupLearningRate<T>(start_multiplier, num_iter);
|
||||
} else if (policy == "constantWarmup") {
|
||||
T multiplier = OperatorBase::template GetSingleArgument<float>(
|
||||
arg_prefix + "multiplier", 0.5);
|
||||
int num_iter = OperatorBase::template GetSingleArgument<int>(
|
||||
arg_prefix + "num_iter", 0);
|
||||
DCHECK_GT(multiplier, 0);
|
||||
return new ConstantWarmupLearningRate<T>(multiplier, num_iter);
|
||||
} else if (policy == "composite") {
|
||||
std::vector<int> sub_policy_num_iters =
|
||||
OperatorBase::template GetRepeatedArgument<int>(
|
||||
"sub_policy_num_iters");
|
||||
std::list<CompositeLearningRateItem<T>> sub_policies;
|
||||
CAFFE_ENFORCE_GT(
|
||||
sub_policy_num_iters.size(),
|
||||
0,
|
||||
"Must specify at least one sub learning rate policy.");
|
||||
for (int i = 0; i < sub_policy_num_iters.size(); ++i) {
|
||||
CAFFE_ENFORCE_GT(
|
||||
sub_policy_num_iters[i],
|
||||
0,
|
||||
"The number of iterations for sub learning rate policy should be positive.");
|
||||
std::stringstream sub_policy_arg_prefix;
|
||||
sub_policy_arg_prefix << "sub_policy_" << i << "_";
|
||||
const string sub_policy_arg_prefix_str = sub_policy_arg_prefix.str();
|
||||
const string sub_policy = OperatorBase::GetSingleArgument<string>(
|
||||
sub_policy_arg_prefix_str + "policy", "");
|
||||
if (sub_policy == "composite") {
|
||||
CAFFE_THROW(
|
||||
"Defining composite LR policy as a subpolicy of composite LR "
|
||||
"policy is not allowed.");
|
||||
}
|
||||
sub_policies.push_back(CompositeLearningRateItem<T>(
|
||||
sub_policy_num_iters[i],
|
||||
createLearningRateFunctor(sub_policy, sub_policy_arg_prefix_str)));
|
||||
}
|
||||
return new CompositeLearningRate<T>(sub_policies);
|
||||
} else {
|
||||
CAFFE_THROW("Unknown learning rate policy: ", policy);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_SGD_LEARNING_RATE_OP_H_
|
||||
#endif // CAFFE2_SGD_LEARNING_RATE_OP_H_
|
||||
|
@ -28,9 +28,11 @@ struct DepthwiseArgs {
|
||||
|
||||
#ifdef __ARM_NEON__
|
||||
|
||||
static inline void
|
||||
winograd_f2k3_input_transform_inplace__neon(float32x4_t *d0, float32x4_t *d1,
|
||||
float32x4_t *d2, float32x4_t *d3) {
|
||||
static inline void winograd_f2k3_input_transform_inplace__neon(
|
||||
float32x4_t* d0,
|
||||
float32x4_t* d1,
|
||||
float32x4_t* d2,
|
||||
float32x4_t* d3) {
|
||||
//*d7 = wd7;
|
||||
float32x4_t wd0 = *d0 - *d2;
|
||||
float32x4_t wd1 = *d1 + *d2;
|
||||
@ -42,15 +44,17 @@ winograd_f2k3_input_transform_inplace__neon(float32x4_t *d0, float32x4_t *d1,
|
||||
*d3 = wd3;
|
||||
}
|
||||
|
||||
static inline void
|
||||
winograd_f2k3_output_transform_inplace__neon(float32x4_t *m0, float32x4_t *m1,
|
||||
float32x4_t *m2, float32x4_t *m3) {
|
||||
static inline void winograd_f2k3_output_transform_inplace__neon(
|
||||
float32x4_t* m0,
|
||||
float32x4_t* m1,
|
||||
float32x4_t* m2,
|
||||
float32x4_t* m3) {
|
||||
*m0 = *m0 + *m1 + *m2;
|
||||
*m1 = *m1 - *m2 - *m3;
|
||||
}
|
||||
|
||||
static inline float32x4_t vmuladdq_f32(float32x4_t c, float32x4_t a,
|
||||
float32x4_t b) {
|
||||
static inline float32x4_t
|
||||
vmuladdq_f32(float32x4_t c, float32x4_t a, float32x4_t b) {
|
||||
#if defined(__aarch64__)
|
||||
return vfmaq_f32(c, a, b);
|
||||
#else
|
||||
@ -58,8 +62,8 @@ static inline float32x4_t vmuladdq_f32(float32x4_t c, float32x4_t a,
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline float32x4_t vmulsubq_f32(float32x4_t c, float32x4_t a,
|
||||
float32x4_t b) {
|
||||
static inline float32x4_t
|
||||
vmulsubq_f32(float32x4_t c, float32x4_t a, float32x4_t b) {
|
||||
#if defined(__aarch64__)
|
||||
return vfmsq_f32(c, a, b);
|
||||
#else
|
||||
@ -68,9 +72,13 @@ static inline float32x4_t vmulsubq_f32(float32x4_t c, float32x4_t a,
|
||||
}
|
||||
|
||||
static inline void winograd_f2k3_kernel_transform__neon(
|
||||
const float32x4_t g0, const float32x4_t g1, const float32x4_t g2,
|
||||
float32x4_t *transform0, float32x4_t *transform1, float32x4_t *transform2,
|
||||
float32x4_t *transform3) {
|
||||
const float32x4_t g0,
|
||||
const float32x4_t g1,
|
||||
const float32x4_t g2,
|
||||
float32x4_t* transform0,
|
||||
float32x4_t* transform1,
|
||||
float32x4_t* transform2,
|
||||
float32x4_t* transform3) {
|
||||
const float32x4_t const_half = vdupq_n_f32(0.5f);
|
||||
float32x4_t half_g0_plus_g2 = const_half * (g0 + g2);
|
||||
*transform0 = g0;
|
||||
@ -81,13 +89,16 @@ static inline void winograd_f2k3_kernel_transform__neon(
|
||||
|
||||
static inline float32x4x4_t v4f_transpose4x4__neon(float32x4x4_t m) {
|
||||
float32x4x4_t ret;
|
||||
vst4q_f32((float *)(&ret), m);
|
||||
vst4q_f32((float*)(&ret), m);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
const float *kernel, const float *bias,
|
||||
float *output) {
|
||||
void runDepthwise3x3Conv(
|
||||
const DepthwiseArgs& args,
|
||||
const float* input,
|
||||
const float* kernel,
|
||||
const float* bias,
|
||||
float* output) {
|
||||
const float32x4_t vbias = vsetq_lane_f32(*bias, vdupq_n_f32(0.0), 1);
|
||||
float32x4x4_t kernel_tile;
|
||||
{
|
||||
@ -97,36 +108,49 @@ void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
const float32x4_t g2 =
|
||||
vextq_f32(vld1q_f32(kernel + 5), vld1q_f32(kernel + 5), 1);
|
||||
float32x4x4_t w;
|
||||
winograd_f2k3_kernel_transform__neon(g0, g1, g2, &w.val[0], &w.val[1],
|
||||
&w.val[2], &w.val[3]);
|
||||
winograd_f2k3_kernel_transform__neon(
|
||||
g0, g1, g2, &w.val[0], &w.val[1], &w.val[2], &w.val[3]);
|
||||
w = v4f_transpose4x4__neon(w);
|
||||
|
||||
winograd_f2k3_kernel_transform__neon(
|
||||
w.val[0], w.val[1], w.val[2], &kernel_tile.val[0], &kernel_tile.val[1],
|
||||
&kernel_tile.val[2], &kernel_tile.val[3]);
|
||||
w.val[0],
|
||||
w.val[1],
|
||||
w.val[2],
|
||||
&kernel_tile.val[0],
|
||||
&kernel_tile.val[1],
|
||||
&kernel_tile.val[2],
|
||||
&kernel_tile.val[3]);
|
||||
}
|
||||
|
||||
#define TILE \
|
||||
winograd_f2k3_input_transform_inplace__neon( \
|
||||
&input_tile.val[0], &input_tile.val[1], &input_tile.val[2], \
|
||||
&input_tile.val[3]); \
|
||||
input_tile = v4f_transpose4x4__neon(input_tile); \
|
||||
winograd_f2k3_input_transform_inplace__neon( \
|
||||
&input_tile.val[0], &input_tile.val[1], &input_tile.val[2], \
|
||||
&input_tile.val[3]); \
|
||||
\
|
||||
for (int row = 0; row < 4; ++row) { \
|
||||
input_tile.val[row] = \
|
||||
vmulq_f32(input_tile.val[row], kernel_tile.val[row]); \
|
||||
} \
|
||||
\
|
||||
input_tile.val[1] = input_tile.val[1] + vbias; \
|
||||
winograd_f2k3_output_transform_inplace__neon( \
|
||||
&input_tile.val[0], &input_tile.val[1], &input_tile.val[2], \
|
||||
&input_tile.val[3]); \
|
||||
input_tile = v4f_transpose4x4__neon(input_tile); \
|
||||
winograd_f2k3_output_transform_inplace__neon( \
|
||||
&input_tile.val[0], &input_tile.val[1], &input_tile.val[2], \
|
||||
#define TILE \
|
||||
winograd_f2k3_input_transform_inplace__neon( \
|
||||
&input_tile.val[0], \
|
||||
&input_tile.val[1], \
|
||||
&input_tile.val[2], \
|
||||
&input_tile.val[3]); \
|
||||
input_tile = v4f_transpose4x4__neon(input_tile); \
|
||||
winograd_f2k3_input_transform_inplace__neon( \
|
||||
&input_tile.val[0], \
|
||||
&input_tile.val[1], \
|
||||
&input_tile.val[2], \
|
||||
&input_tile.val[3]); \
|
||||
\
|
||||
for (int row = 0; row < 4; ++row) { \
|
||||
input_tile.val[row] = \
|
||||
vmulq_f32(input_tile.val[row], kernel_tile.val[row]); \
|
||||
} \
|
||||
\
|
||||
input_tile.val[1] = input_tile.val[1] + vbias; \
|
||||
winograd_f2k3_output_transform_inplace__neon( \
|
||||
&input_tile.val[0], \
|
||||
&input_tile.val[1], \
|
||||
&input_tile.val[2], \
|
||||
&input_tile.val[3]); \
|
||||
input_tile = v4f_transpose4x4__neon(input_tile); \
|
||||
winograd_f2k3_output_transform_inplace__neon( \
|
||||
&input_tile.val[0], \
|
||||
&input_tile.val[1], \
|
||||
&input_tile.val[2], \
|
||||
&input_tile.val[3])
|
||||
|
||||
// Non-padded regime.
|
||||
@ -139,11 +163,11 @@ void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
int ih = oth * 2 - args.pad_rows;
|
||||
int iw = otw * 2 - args.pad_cols;
|
||||
// fast-path, all accesses in-bounds
|
||||
if (__builtin_expect(ih >= 0 && iw >= 0 && ih + 3 < args.in_rows &&
|
||||
iw + 3 < args.in_cols &&
|
||||
2 * oth + 1 < args.out_rows &&
|
||||
2 * otw + 1 < args.out_cols,
|
||||
1)) {
|
||||
if (__builtin_expect(
|
||||
ih >= 0 && iw >= 0 && ih + 3 < args.in_rows &&
|
||||
iw + 3 < args.in_cols && 2 * oth + 1 < args.out_rows &&
|
||||
2 * otw + 1 < args.out_cols,
|
||||
1)) {
|
||||
float32x4x4_t input_tile;
|
||||
for (int row = 0; row < 4; ++row) {
|
||||
input_tile.val[row] =
|
||||
@ -153,8 +177,9 @@ void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
TILE;
|
||||
|
||||
for (size_t row = 0; row < 2; ++row) {
|
||||
vst1_f32(output + (oth * 2 + row) * args.out_cols + otw * 2,
|
||||
vget_low_f32(input_tile.val[row]));
|
||||
vst1_f32(
|
||||
output + (oth * 2 + row) * args.out_cols + otw * 2,
|
||||
vget_low_f32(input_tile.val[row]));
|
||||
}
|
||||
} else {
|
||||
float block[4][4];
|
||||
@ -200,12 +225,12 @@ void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
typedef float psimd_f32 __attribute__((vector_size(16), aligned(1)));
|
||||
typedef int psimd_s32 __attribute__((__vector_size__(16)));
|
||||
|
||||
PSIMD_INTRINSIC void psimd_store_f32(void *address, psimd_f32 value) {
|
||||
*((psimd_f32 *)address) = value;
|
||||
PSIMD_INTRINSIC void psimd_store_f32(void* address, psimd_f32 value) {
|
||||
*((psimd_f32*)address) = value;
|
||||
}
|
||||
|
||||
PSIMD_INTRINSIC psimd_f32 psimd_load_f32(const void *address) {
|
||||
return *((const psimd_f32 *)address);
|
||||
PSIMD_INTRINSIC psimd_f32 psimd_load_f32(const void* address) {
|
||||
return *((const psimd_f32*)address);
|
||||
}
|
||||
|
||||
PSIMD_INTRINSIC psimd_f32 psimd_splat_f32(float c) {
|
||||
@ -249,12 +274,15 @@ PSIMD_INTRINSIC psimd_f32 psimd_concat_hi_f32(psimd_f32 a, psimd_f32 b) {
|
||||
|
||||
#endif
|
||||
|
||||
static inline void psimd_transpose4x4_f32(const psimd_f32 row0,
|
||||
const psimd_f32 row1,
|
||||
const psimd_f32 row2,
|
||||
const psimd_f32 row3, psimd_f32 *col0,
|
||||
psimd_f32 *col1, psimd_f32 *col2,
|
||||
psimd_f32 *col3) {
|
||||
static inline void psimd_transpose4x4_f32(
|
||||
const psimd_f32 row0,
|
||||
const psimd_f32 row1,
|
||||
const psimd_f32 row2,
|
||||
const psimd_f32 row3,
|
||||
psimd_f32* col0,
|
||||
psimd_f32* col1,
|
||||
psimd_f32* col2,
|
||||
psimd_f32* col3) {
|
||||
const psimd_f32 row01lo = psimd_interleave_lo_f32(row0, row1);
|
||||
const psimd_f32 row01hi = psimd_interleave_hi_f32(row0, row1);
|
||||
const psimd_f32 row23lo = psimd_interleave_lo_f32(row2, row3);
|
||||
@ -265,22 +293,29 @@ static inline void psimd_transpose4x4_f32(const psimd_f32 row0,
|
||||
*col3 = psimd_concat_hi_f32(row01hi, row23hi);
|
||||
}
|
||||
|
||||
static inline void
|
||||
winograd_f2k3_input_transform(const psimd_f32 d0, const psimd_f32 d1,
|
||||
const psimd_f32 d2, const psimd_f32 d3,
|
||||
psimd_f32 *transform0, psimd_f32 *transform1,
|
||||
psimd_f32 *transform2, psimd_f32 *transform3) {
|
||||
static inline void winograd_f2k3_input_transform(
|
||||
const psimd_f32 d0,
|
||||
const psimd_f32 d1,
|
||||
const psimd_f32 d2,
|
||||
const psimd_f32 d3,
|
||||
psimd_f32* transform0,
|
||||
psimd_f32* transform1,
|
||||
psimd_f32* transform2,
|
||||
psimd_f32* transform3) {
|
||||
*transform0 = d0 - d2;
|
||||
*transform1 = d1 + d2;
|
||||
*transform2 = -d1 + d2;
|
||||
*transform3 = d1 - d3;
|
||||
}
|
||||
|
||||
static inline void
|
||||
winograd_f2k3_kernel_transform(const psimd_f32 g0, const psimd_f32 g1,
|
||||
const psimd_f32 g2, psimd_f32 *transform0,
|
||||
psimd_f32 *transform1, psimd_f32 *transform2,
|
||||
psimd_f32 *transform3) {
|
||||
static inline void winograd_f2k3_kernel_transform(
|
||||
const psimd_f32 g0,
|
||||
const psimd_f32 g1,
|
||||
const psimd_f32 g2,
|
||||
psimd_f32* transform0,
|
||||
psimd_f32* transform1,
|
||||
psimd_f32* transform2,
|
||||
psimd_f32* transform3) {
|
||||
const psimd_f32 const_half = psimd_splat_f32(0.5);
|
||||
const psimd_f32 half_g0_plus_g2 = const_half * (g0 + g2);
|
||||
*transform0 = g0;
|
||||
@ -289,17 +324,23 @@ winograd_f2k3_kernel_transform(const psimd_f32 g0, const psimd_f32 g1,
|
||||
*transform3 = g2;
|
||||
}
|
||||
|
||||
static inline void
|
||||
winograd_f2k3_output_transform(const psimd_f32 m0, const psimd_f32 m1,
|
||||
const psimd_f32 m2, const psimd_f32 m3,
|
||||
psimd_f32 *output0, psimd_f32 *output1) {
|
||||
static inline void winograd_f2k3_output_transform(
|
||||
const psimd_f32 m0,
|
||||
const psimd_f32 m1,
|
||||
const psimd_f32 m2,
|
||||
const psimd_f32 m3,
|
||||
psimd_f32* output0,
|
||||
psimd_f32* output1) {
|
||||
*output0 = m0 + m1 + m2;
|
||||
*output1 = m1 - m2 - m3;
|
||||
}
|
||||
|
||||
void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
const float *kernel, const float *bias,
|
||||
float *output) {
|
||||
void runDepthwise3x3Conv(
|
||||
const DepthwiseArgs& args,
|
||||
const float* input,
|
||||
const float* kernel,
|
||||
const float* bias,
|
||||
float* output) {
|
||||
const psimd_f32 vbias = {0, *bias, 0, 0};
|
||||
const psimd_f32 g0 = psimd_load_f32(kernel);
|
||||
const psimd_f32 g1 = psimd_load_f32(kernel + 3);
|
||||
@ -314,8 +355,8 @@ void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
winograd_f2k3_kernel_transform(g0, g1, g2, &w[0], &w[1], &w[2], &w[3]);
|
||||
psimd_transpose4x4_f32(w[0], w[1], w[2], w[3], &w[0], &w[1], &w[2], &w[3]);
|
||||
psimd_f32 wg[4];
|
||||
winograd_f2k3_kernel_transform(w[0], w[1], w[2], &wg[0], &wg[1], &wg[2],
|
||||
&wg[3]);
|
||||
winograd_f2k3_kernel_transform(
|
||||
w[0], w[1], w[2], &wg[0], &wg[1], &wg[2], &wg[3]);
|
||||
|
||||
// Iterate over non-padded output tiles.
|
||||
for (int oth = 0; oth < (args.out_rows + 1) / 2; ++oth) {
|
||||
@ -338,13 +379,18 @@ void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
}
|
||||
psimd_f32 wd[4];
|
||||
winograd_f2k3_input_transform(
|
||||
psimd_load_f32(&block[0]), psimd_load_f32(&block[1]),
|
||||
psimd_load_f32(&block[2]), psimd_load_f32(&block[3]), &wd[0], &wd[1],
|
||||
&wd[2], &wd[3]);
|
||||
psimd_transpose4x4_f32(wd[0], wd[1], wd[2], wd[3], &wd[0], &wd[1], &wd[2],
|
||||
&wd[3]);
|
||||
winograd_f2k3_input_transform(wd[0], wd[1], wd[2], wd[3], &wd[0], &wd[1],
|
||||
&wd[2], &wd[3]);
|
||||
psimd_load_f32(&block[0]),
|
||||
psimd_load_f32(&block[1]),
|
||||
psimd_load_f32(&block[2]),
|
||||
psimd_load_f32(&block[3]),
|
||||
&wd[0],
|
||||
&wd[1],
|
||||
&wd[2],
|
||||
&wd[3]);
|
||||
psimd_transpose4x4_f32(
|
||||
wd[0], wd[1], wd[2], wd[3], &wd[0], &wd[1], &wd[2], &wd[3]);
|
||||
winograd_f2k3_input_transform(
|
||||
wd[0], wd[1], wd[2], wd[3], &wd[0], &wd[1], &wd[2], &wd[3]);
|
||||
|
||||
for (int row = 0; row < 4; ++row) {
|
||||
wd[row] = wg[row] * wd[row];
|
||||
@ -352,8 +398,8 @@ void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
wd[1] += vbias;
|
||||
psimd_f32 s[4] = {{0}};
|
||||
winograd_f2k3_output_transform(wd[0], wd[1], wd[2], wd[3], &s[0], &s[1]);
|
||||
psimd_transpose4x4_f32(s[0], s[1], s[2], s[3], &s[0], &s[1], &s[2],
|
||||
&s[3]);
|
||||
psimd_transpose4x4_f32(
|
||||
s[0], s[1], s[2], s[3], &s[0], &s[1], &s[2], &s[3]);
|
||||
|
||||
psimd_f32 t0, t1;
|
||||
winograd_f2k3_output_transform(s[0], s[1], s[2], s[3], &t0, &t1);
|
||||
@ -377,12 +423,13 @@ void runDepthwise3x3Conv(const DepthwiseArgs &args, const float *input,
|
||||
#endif
|
||||
|
||||
class Depthwise3x3ConvOp final : public ConvPoolOpBase<CPUContext> {
|
||||
public:
|
||||
public:
|
||||
USE_CONV_POOL_BASE_FUNCTIONS(CPUContext);
|
||||
Depthwise3x3ConvOp(const OperatorDef &operator_def, Workspace *ws)
|
||||
Depthwise3x3ConvOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: ConvPoolOpBase<CPUContext>(operator_def, ws) {
|
||||
OPERATOR_NEEDS_FEATURE(this->order_ == StorageOrder::NCHW,
|
||||
"Depthwise3x3ConvOp only supports NCHW order");
|
||||
OPERATOR_NEEDS_FEATURE(
|
||||
this->order_ == StorageOrder::NCHW,
|
||||
"Depthwise3x3ConvOp only supports NCHW order");
|
||||
OPERATOR_NEEDS_FEATURE(this->group_ > 1);
|
||||
OPERATOR_NEEDS_FEATURE(this->kernel_w() == 3);
|
||||
OPERATOR_NEEDS_FEATURE(this->kernel_h() == 3);
|
||||
@ -391,9 +438,9 @@ public:
|
||||
}
|
||||
|
||||
bool RunOnDeviceWithOrderNCHW() override {
|
||||
const Tensor<CPUContext> &X = Input(0);
|
||||
auto &filter = Input(1);
|
||||
Tensor<CPUContext> *Y = Output(0);
|
||||
const Tensor<CPUContext>& X = Input(0);
|
||||
auto& filter = Input(1);
|
||||
Tensor<CPUContext>* Y = Output(0);
|
||||
const int N = X.dim32(0), C = X.dim32(1);
|
||||
CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
|
||||
const int M = filter.dim32(0);
|
||||
@ -423,16 +470,19 @@ public:
|
||||
if (InputSize() != 3 && bias_.size() != M) {
|
||||
// no bias.
|
||||
bias_.Resize(M);
|
||||
math::Set<float, CPUContext>(M, 0.0, bias_.mutable_data<float>(),
|
||||
&context_);
|
||||
math::Set<float, CPUContext>(
|
||||
M, 0.0, bias_.mutable_data<float>(), &context_);
|
||||
}
|
||||
const auto *bias =
|
||||
const auto* bias =
|
||||
InputSize() == 3 ? Input(2).data<float>() : bias_.data<float>();
|
||||
|
||||
auto f = [&](int n, int g) {
|
||||
runDepthwise3x3Conv(args, X.data<float>() + g * IS + n * G * IS,
|
||||
filter.data<float>() + g * 3 * 3, bias + g,
|
||||
Y->mutable_data<float>() + g * OS + n * G * OS);
|
||||
runDepthwise3x3Conv(
|
||||
args,
|
||||
X.data<float>() + g * IS + n * G * IS,
|
||||
filter.data<float>() + g * 3 * 3,
|
||||
bias + g,
|
||||
Y->mutable_data<float>() + g * OS + n * G * OS);
|
||||
};
|
||||
|
||||
Timer t;
|
||||
@ -485,7 +535,7 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
Tensor<CPUContext> bias_;
|
||||
};
|
||||
|
||||
|
@ -171,8 +171,7 @@ void runConv(
|
||||
int group = 1,
|
||||
int planesIn = randInt(1, 6),
|
||||
int planesOut = randInt(1, 6),
|
||||
int n = randInt(1, 2))
|
||||
{
|
||||
int n = randInt(1, 2)) {
|
||||
int h = randInt(20, 100);
|
||||
int w = randInt(20, 100);
|
||||
// This pad restriction is imposed by NNPACK
|
||||
@ -210,12 +209,10 @@ constexpr size_t kIters = 20;
|
||||
TEST(DEPTHWISE3x3, Conv) {
|
||||
for (int i = 0; i < kIters; ++i) {
|
||||
int channel = 2;
|
||||
runConv(
|
||||
3, 3, 1, 1, channel, channel, channel, randInt(1, 2));
|
||||
runConv(3, 3, 1, 1, channel, channel, channel, randInt(1, 2));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
|
@ -81,8 +81,7 @@ void compare(
|
||||
"convolution_transform_strategy", convolutionTransformStrategy));
|
||||
}
|
||||
if (!activation.empty()) {
|
||||
nnpackOpDef.add_arg()->CopyFrom(MakeArgument(
|
||||
"activation", activation));
|
||||
nnpackOpDef.add_arg()->CopyFrom(MakeArgument("activation", activation));
|
||||
}
|
||||
nnpackOpDef.add_arg()->CopyFrom(MakeArgument("stride_h", strideH));
|
||||
nnpackOpDef.add_arg()->CopyFrom(MakeArgument("stride_w", strideW));
|
||||
@ -132,7 +131,6 @@ void compare(
|
||||
EXPECT_NE(nullptr, activationOp.get());
|
||||
}
|
||||
|
||||
|
||||
for (auto i = 0; i < 10; ++i) {
|
||||
EXPECT_TRUE(nnpackOp->Run());
|
||||
}
|
||||
@ -313,7 +311,17 @@ TEST(NNPACK, ConvRelu_1x1s1) {
|
||||
auto outChannels = randInt(1, 8) * group;
|
||||
auto n = 1;
|
||||
runConv(
|
||||
1, 1, 1, 1, group, "DIRECT", inChannels, outChannels, n, "PRECOMPUTE", "Relu");
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
group,
|
||||
"DIRECT",
|
||||
inChannels,
|
||||
outChannels,
|
||||
n,
|
||||
"PRECOMPUTE",
|
||||
"Relu");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,12 @@
|
||||
#
|
||||
# This module finds the Intel Mkl libraries.
|
||||
#
|
||||
# USE_IDEEP : use IDEEP interface
|
||||
# USE_MKLML : use MKLML interface
|
||||
# MKLML_USE_SINGLE_DYNAMIC_LIBRARY : use single dynamic library interface
|
||||
# MKLML_USE_STATIC_LIBS : use static libraries
|
||||
# MKLML_MULTI_THREADED : use multi-threading
|
||||
#
|
||||
# This module sets the following variables:
|
||||
# MKL_FOUND - set to true if a library implementing the CBLAS interface is found
|
||||
# MKL_VERSION - best guess
|
||||
|
Reference in New Issue
Block a user