From 238ceab825d0986d469dbe2d046d245782126d9b Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Mon, 14 Nov 2016 14:58:04 -0800 Subject: [PATCH] fbsync. TODO: check if build files need update. --- LICENSE | 38 +- caffe2/binaries/split_db.cc | 53 +- caffe2/contrib/nccl/cuda_nccl_gpu.cc | 3 +- caffe2/contrib/nnpack/nnpack_ops.cc | 13 +- caffe2/contrib/nnpack/nnpack_ops_test.py | 5 +- caffe2/core/asan.h | 25 + caffe2/core/blob_serialization.cc | 4 +- caffe2/core/blob_serialization.h | 30 +- caffe2/core/blob_test.cc | 27 +- caffe2/core/common.h | 18 + caffe2/core/common_gpu.cc | 75 +- caffe2/core/common_gpu.h | 11 - caffe2/core/context_gpu.cu | 223 +++-- caffe2/core/context_gpu.h | 29 +- caffe2/core/db.h | 4 +- caffe2/core/logging_test.cc | 2 + caffe2/core/net.cc | 32 +- caffe2/core/net.h | 3 +- caffe2/core/net_test.cc | 6 +- caffe2/core/operator.cc | 46 +- caffe2/core/operator.h | 39 +- caffe2/core/operator_test.cc | 37 +- caffe2/core/typeid.cc | 8 + caffe2/core/typeid.h | 8 +- caffe2/core/workspace.cc | 74 +- caffe2/core/workspace.h | 30 +- caffe2/mpi/mpi_gpu_test.cc | 10 +- caffe2/mpi/mpi_python.cc | 12 + caffe2/operators/concat_split_op.h | 6 +- caffe2/operators/conv_transpose_op.cc | 1 + caffe2/operators/conv_transpose_op.h | 4 +- caffe2/operators/conv_transpose_op_impl.h | 23 +- .../operators/conv_transpose_unpool_op_base.h | 4 +- caffe2/operators/counter_ops.cc | 66 +- caffe2/operators/cross_entropy_op.cc | 2 +- caffe2/operators/dataset_ops.cc | 200 +++-- caffe2/operators/distance_op.cc | 12 +- caffe2/operators/elementwise_op_schema.cc | 22 +- caffe2/operators/elu_op.cc | 81 ++ caffe2/operators/elu_op.h | 37 + caffe2/operators/fully_connected_op.h | 7 +- caffe2/operators/fully_connected_op_test.cc | 3 + caffe2/operators/h_softmax_op.cc | 210 ++++- caffe2/operators/h_softmax_op.h | 178 ++-- caffe2/operators/load_save_op.cc | 6 +- caffe2/operators/load_save_op.h | 62 +- caffe2/operators/lp_pool_op.cc | 273 ++++++ caffe2/operators/lp_pool_op.cu | 349 ++++++++ caffe2/operators/metrics_ops.cc | 53 ++ caffe2/operators/metrics_ops.h | 85 ++ caffe2/operators/pack_segments.cc | 10 +- caffe2/operators/packed_fc_op.cc | 64 +- caffe2/operators/partition_ops.cc | 29 +- caffe2/operators/pool_op.cc | 166 ++++ caffe2/operators/prelu_op.cc | 300 +++++++ caffe2/operators/prelu_op.h | 40 + caffe2/operators/softmax_op.cc | 42 +- caffe2/operators/softmax_op.cu | 18 +- caffe2/operators/softmax_op.h | 10 +- caffe2/operators/softmax_shared.cc | 55 ++ caffe2/operators/softmax_shared.h | 19 + caffe2/operators/softmax_with_loss_op.cc | 278 +++++++ caffe2/operators/softmax_with_loss_op.cu | 396 +++++++++ caffe2/operators/softmax_with_loss_op.h | 63 ++ caffe2/operators/softsign_op.cc | 50 ++ caffe2/operators/softsign_op.cu | 25 +- caffe2/operators/spatial_batch_norm_op.cc | 6 +- caffe2/operators/utility_ops.cc | 35 +- caffe2/operators/utility_ops.h | 108 ++- caffe2/operators/workspace_ops.cc | 42 + caffe2/proto/caffe2.proto | 12 +- caffe2/proto/hsm.proto | 3 + caffe2/python/_import_c_extension.py | 12 + caffe2/python/caffe_translator.py | 42 + caffe2/python/cnn.py | 64 +- caffe2/python/context.py | 101 +++ caffe2/python/control.py | 491 +++++++---- caffe2/python/control_test.py | 208 +++-- caffe2/python/convnet_benchmarks.py | 5 +- caffe2/python/core.py | 357 ++++++-- caffe2/python/data_parallel_model.py | 779 ++++++++---------- caffe2/python/data_parallel_model_test.py | 46 +- caffe2/python/dataio.py | 136 ++- caffe2/python/dataio_test.py | 52 ++ caffe2/python/dataset.py | 104 ++- caffe2/python/dyndep.py | 14 +- caffe2/python/experiment_util.py | 10 +- caffe2/python/hsm_test.py | 100 ++- caffe2/python/hsm_util.py | 7 +- caffe2/python/hypothesis_test.py | 35 +- caffe2/python/layer_model_helper.py | 295 +++++++ caffe2/python/layer_model_instantiator.py | 44 + caffe2/python/layers/__init__.py | 27 + caffe2/python/layers/batch_lr_loss.py | 44 + caffe2/python/layers/concat.py | 56 ++ caffe2/python/layers/fc.py | 64 ++ caffe2/python/layers/layers.py | 87 ++ .../python/layers/simple_operator_layers.py | 67 ++ caffe2/python/layers/sparse_lookup.py | 96 +++ caffe2/python/layers/sparse_to_dense.py | 131 +++ caffe2/python/layers/tags.py | 50 ++ caffe2/python/load_save_test.py | 50 +- caffe2/python/model_helper.py | 147 +++- caffe2/python/models/resnet.py | 14 +- caffe2/python/net_builder.py | 251 ++++++ caffe2/python/net_builder_test.py | 82 ++ caffe2/python/net_drawer.py | 8 +- caffe2/python/op/python.py | 27 - caffe2/python/op/python_op.cpp | 206 ----- .../operator_test/activation_ops_test.py | 81 ++ .../operator_test/conv_transpose_test.py | 8 +- .../python/operator_test/counter_ops_test.py | 21 + .../python/operator_test/dataset_ops_test.py | 16 +- caffe2/python/operator_test/matmul_op_test.py | 8 +- caffe2/python/operator_test/mkl_ops_test.py | 10 +- caffe2/python/operator_test/pack_ops_test.py | 25 + caffe2/python/operator_test/pooling_test.py | 13 +- caffe2/python/operator_test/python_op_test.py | 46 ++ .../python/operator_test/reshape_ops_test.py | 27 +- .../python/operator_test/softmax_ops_test.py | 266 ++++++ caffe2/python/pipeline.py | 358 ++++++-- caffe2/python/pybind_state.cc | 169 +++- caffe2/python/pybind_state.h | 105 ++- caffe2/python/pybind_state_gpu.cc | 8 +- .../{op/python_test.py => python_op_test.py} | 16 +- caffe2/python/queue_util.py | 74 +- caffe2/python/schema.py | 272 +++++- caffe2/python/schema_test.py | 42 +- caffe2/python/scope.py | 42 +- caffe2/python/scope_test.py | 83 ++ caffe2/python/session.py | 147 ++++ caffe2/python/session_test.py | 60 ++ caffe2/python/snapshot.py | 263 ++++++ caffe2/python/snapshot_test.py | 95 +++ caffe2/python/task.py | 482 +++++++++++ caffe2/python/timeout_guard.py | 56 ++ caffe2/python/workspace.py | 7 +- caffe2/queue/queue_ops.h | 9 +- caffe2/sgd/adam_op.h | 6 +- caffe2/sgd/adam_op_gpu.cu | 2 +- caffe2/utils/cpu_neon.h | 61 ++ caffe2/utils/fixed_divisor.h | 146 ++++ caffe2/utils/math.h | 22 + caffe2/utils/math_cpu.cc | 184 ++++- caffe2/utils/mkl/sgemm_pack.h | 15 +- caffe2/utils/proto_utils.h | 4 +- caffe2/utils/threadpool/ThreadPool.cc | 231 ++++++ caffe2/utils/threadpool/ThreadPool.h | 143 ++++ caffe2/utils/threadpool/ThreadPoolCommon.h | 26 + caffe2/utils/threadpool/pthreadpool.cc | 169 ++++ caffe2/utils/threadpool/pthreadpool.h | 111 +++ caffe2/utils/threadpool/pthreadpool_impl.cc | 26 + caffe2/utils/threadpool/pthreadpool_impl.h | 30 + 153 files changed, 10718 insertions(+), 1896 deletions(-) create mode 100644 caffe2/core/asan.h create mode 100644 caffe2/operators/elu_op.cc create mode 100644 caffe2/operators/elu_op.h create mode 100644 caffe2/operators/lp_pool_op.cc create mode 100644 caffe2/operators/lp_pool_op.cu create mode 100644 caffe2/operators/metrics_ops.cc create mode 100644 caffe2/operators/metrics_ops.h create mode 100644 caffe2/operators/prelu_op.cc create mode 100644 caffe2/operators/prelu_op.h create mode 100644 caffe2/operators/softmax_shared.cc create mode 100644 caffe2/operators/softmax_shared.h create mode 100644 caffe2/operators/softmax_with_loss_op.cc create mode 100644 caffe2/operators/softmax_with_loss_op.cu create mode 100644 caffe2/operators/softmax_with_loss_op.h create mode 100644 caffe2/operators/workspace_ops.cc create mode 100644 caffe2/python/context.py create mode 100644 caffe2/python/dataio_test.py create mode 100644 caffe2/python/layer_model_helper.py create mode 100644 caffe2/python/layer_model_instantiator.py create mode 100644 caffe2/python/layers/__init__.py create mode 100644 caffe2/python/layers/batch_lr_loss.py create mode 100644 caffe2/python/layers/concat.py create mode 100644 caffe2/python/layers/fc.py create mode 100644 caffe2/python/layers/layers.py create mode 100644 caffe2/python/layers/simple_operator_layers.py create mode 100644 caffe2/python/layers/sparse_lookup.py create mode 100644 caffe2/python/layers/sparse_to_dense.py create mode 100644 caffe2/python/layers/tags.py create mode 100644 caffe2/python/net_builder.py create mode 100644 caffe2/python/net_builder_test.py delete mode 100644 caffe2/python/op/python.py delete mode 100644 caffe2/python/op/python_op.cpp create mode 100644 caffe2/python/operator_test/activation_ops_test.py create mode 100644 caffe2/python/operator_test/python_op_test.py create mode 100644 caffe2/python/operator_test/softmax_ops_test.py rename caffe2/python/{op/python_test.py => python_op_test.py} (91%) create mode 100644 caffe2/python/scope_test.py create mode 100644 caffe2/python/session.py create mode 100644 caffe2/python/session_test.py create mode 100644 caffe2/python/snapshot.py create mode 100644 caffe2/python/snapshot_test.py create mode 100644 caffe2/python/task.py create mode 100644 caffe2/python/timeout_guard.py create mode 100644 caffe2/utils/cpu_neon.h create mode 100644 caffe2/utils/fixed_divisor.h create mode 100644 caffe2/utils/threadpool/ThreadPool.cc create mode 100644 caffe2/utils/threadpool/ThreadPool.h create mode 100644 caffe2/utils/threadpool/ThreadPoolCommon.h create mode 100644 caffe2/utils/threadpool/pthreadpool.cc create mode 100644 caffe2/utils/threadpool/pthreadpool.h create mode 100644 caffe2/utils/threadpool/pthreadpool_impl.cc create mode 100644 caffe2/utils/threadpool/pthreadpool_impl.h diff --git a/LICENSE b/LICENSE index 9e3fe9c34272..f0f0527fad0f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,5 +1,8 @@ COPYRIGHT +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + All contributions by Google: Copyright (c) 2015 Google Inc. All rights reserved. @@ -13,7 +16,7 @@ Copyright(c) 2013, 2014, 2015, the respective contributors All rights reserved. All other contributions: -Copyright(c) 2015, the respective contributors +Copyright(c) 2015, 2016 the respective contributors All rights reserved. Caffe2 uses a copyright model similar to Caffe: each contributor holds @@ -124,36 +127,3 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *** end zmqhpp license *** - -Some part of the caffe2 code (specifically, third_party/cnmem) comes from the -open-source cnmem code under the 2-clause BSD license. The cnmem license is -as follows: -*** begin cnmem license *** -/* ********************************************************************** - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of NVIDIA CORPORATION nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY - * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * ********************************************************************** */ -*** end cnmem license *** diff --git a/caffe2/binaries/split_db.cc b/caffe2/binaries/split_db.cc index 2889f6f1b279..6f17ab3eb1de 100644 --- a/caffe2/binaries/split_db.cc +++ b/caffe2/binaries/split_db.cc @@ -11,35 +11,40 @@ CAFFE2_DEFINE_int(splits, 0, "The number of splits."); CAFFE2_DEFINE_string(db_type, "", "The db type."); CAFFE2_DEFINE_int(batch_size, 1000, "The write batch size."); -using caffe2::db::Cursor; -using caffe2::db::DB; -using caffe2::db::Transaction; +namespace caffe2 { -int main(int argc, char** argv) { - caffe2::GlobalInit(&argc, &argv); +static int Split(int argc, char** argv) { + GlobalInit(&argc, &argv); - std::unique_ptr in_db(caffe2::db::CreateDB( - caffe2::FLAGS_db_type, caffe2::FLAGS_input_db, caffe2::db::READ)); - std::unique_ptr cursor(in_db->NewCursor()); + CAFFE_ENFORCE(FLAGS_input_db.size(), "Must specify --input_db=/path/to/db."); + CAFFE_ENFORCE(FLAGS_splits > 0, "Must specify a nonnegative split number."); + CAFFE_ENFORCE(FLAGS_db_type.size(), "Must specify --db_type=[a db type]."); - CHECK_GT(caffe2::FLAGS_splits, 0) << "Must specify the number of splits."; - std::vector > out_dbs; - std::vector > transactions; - for (int i = 0; i < caffe2::FLAGS_splits; ++i) { - out_dbs.push_back( - std::unique_ptr(caffe2::db::CreateDB( - caffe2::FLAGS_db_type, - caffe2::FLAGS_input_db + "_split_" + caffe2::to_string(i), - caffe2::db::NEW))); + unique_ptr in_db( + db::CreateDB(FLAGS_db_type, FLAGS_input_db, db::READ)); + CAFFE_ENFORCE(in_db != nullptr, "Cannot open input db: ", FLAGS_input_db); + unique_ptr cursor(in_db->NewCursor()); + // This usually won't happen, but FWIW. + CAFFE_ENFORCE( + cursor != nullptr, "Cannot obtain cursor for input db: ", FLAGS_input_db); + + vector> out_dbs; + vector> transactions; + for (int i = 0; i < FLAGS_splits; ++i) { + out_dbs.push_back(unique_ptr(db::CreateDB( + FLAGS_db_type, FLAGS_input_db + "_split_" + to_string(i), db::NEW))); + CAFFE_ENFORCE(out_dbs.back().get(), "Cannot create output db #", i); transactions.push_back( - std::unique_ptr(out_dbs[i]->NewTransaction())); + unique_ptr(out_dbs[i]->NewTransaction())); + CAFFE_ENFORCE( + transactions.back().get(), "Cannot get transaction for output db #", i); } int count = 0; for (; cursor->Valid(); cursor->Next()) { - transactions[count % caffe2::FLAGS_splits]->Put(cursor->key(), cursor->value()); - if (++count % caffe2::FLAGS_batch_size == 0) { - for (int i = 0; i < caffe2::FLAGS_splits; ++i) { + transactions[count % FLAGS_splits]->Put(cursor->key(), cursor->value()); + if (++count % FLAGS_batch_size == 0) { + for (int i = 0; i < FLAGS_splits; ++i) { transactions[i]->Commit(); } LOG(INFO) << "Split " << count << " items so far."; @@ -48,3 +53,9 @@ int main(int argc, char** argv) { LOG(INFO) << "A total of " << count << " items processed."; return 0; } + +} // namespace caffe2 + +int main(int argc, char** argv) { + return caffe2::Split(argc, argv); +} diff --git a/caffe2/contrib/nccl/cuda_nccl_gpu.cc b/caffe2/contrib/nccl/cuda_nccl_gpu.cc index 41f81eb21d38..43c82660261e 100644 --- a/caffe2/contrib/nccl/cuda_nccl_gpu.cc +++ b/caffe2/contrib/nccl/cuda_nccl_gpu.cc @@ -30,7 +30,8 @@ class NCCLContext { // get stream priorities int lo_pri, hi_pri; CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri)); - CUDA_CHECK(cudaStreamCreateWithPriority(&streams_[i], cudaStreamNonBlocking, hi_pri)); + CUDA_CHECK(cudaStreamCreateWithPriority( + &streams_[i], cudaStreamNonBlocking, hi_pri)); CUDA_CHECK(cudaEventCreateWithFlags( &events_[i], cudaEventDefault | cudaEventDisableTiming)); } diff --git a/caffe2/contrib/nnpack/nnpack_ops.cc b/caffe2/contrib/nnpack/nnpack_ops.cc index d3931ebb5317..1a68be3d740e 100644 --- a/caffe2/contrib/nnpack/nnpack_ops.cc +++ b/caffe2/contrib/nnpack/nnpack_ops.cc @@ -76,6 +76,8 @@ class NNPACKConvOp final : public ConvPoolOpBase { this->order_ == StorageOrder::NCHW, "NNPack only supports NCHW order. Please consider adding " "TransposeOp with axes=[0, 3, 1, 2] before NNPack Conv."); + OPERATOR_NEEDS_FEATURE( + __builtin_cpu_supports("avx2"), "NNPack requires AVX2"); } bool RunOnDeviceWithOrderNCHW() override; @@ -101,8 +103,7 @@ bool NNPACKConvOp::RunOnDeviceWithOrderNCHW() { CAFFE_ENFORCE(filter.dim32(1) == C, ""); CAFFE_ENFORCE(filter.dim32(2) == this->kernel_h_, ""); CAFFE_ENFORCE(filter.dim32(3) == this->kernel_w_, ""); - CAFFE_ENFORCE(bias.ndim() == 1, ""); - CAFFE_ENFORCE(bias.dim32(0) == M, ""); + CAFFE_ENFORCE(bias.size() == M, ""); ConvPoolOpBase::SetOutputSize(X, Y, filter.dim32(0)); if (N > 1) { // NNPack only supports stride = 1 when doing batch feedforward @@ -200,6 +201,8 @@ class NNPACKMaxPoolOp final : public ConvPoolOpBase { OPERATOR_NEEDS_FEATURE( this->pad_b_ == 0, "NNPack Pooling differs from Caffe2 Pooling when pad > 0!"); + OPERATOR_NEEDS_FEATURE( + __builtin_cpu_supports("avx2"), "NNPack requires AVX2"); } bool RunOnDeviceWithOrderNCHW() override; @@ -215,12 +218,6 @@ bool NNPACKMaxPoolOp::RunOnDeviceWithOrderNCHW() { auto* Y = Output(0); CAFFE_ENFORCE(X.ndim() == 4, ""); const int H = X.dim32(2), W = X.dim32(3); - CAFFE_ENFORCE( - H % 2 == 0, - "NNPack MaxPool differs from Caffe2 when Input Size is not even!"); - CAFFE_ENFORCE( - W % 2 == 0, - "NNPack MaxPool differs from Caffe2 when Input Size is not even!"); ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1)); std::vector pads( {this->pad_t_, this->pad_b_, this->pad_l_, this->pad_r_}); diff --git a/caffe2/contrib/nnpack/nnpack_ops_test.py b/caffe2/contrib/nnpack/nnpack_ops_test.py index 5316c4b3f493..d2b4feb568c6 100644 --- a/caffe2/contrib/nnpack/nnpack_ops_test.py +++ b/caffe2/contrib/nnpack/nnpack_ops_test.py @@ -43,7 +43,7 @@ def has_avx2(): @unittest.skipIf(not has_avx2(), "NNPACK requires AVX2") class NNPackOpsTest(hu.HypothesisTestCase): - @given(stride=st.integers(1, 1), + @given(stride=st.integers(1, 3), pad=st.integers(0, 2), kernel=st.integers(3, 5), size=st.integers(5, 10), @@ -54,6 +54,9 @@ class NNPackOpsTest(hu.HypothesisTestCase): input_channels, output_channels, batch_size): assume(stride <= kernel) + if stride != 1: + assume(batch_size == 1) + X = np.random.rand( batch_size, input_channels, size, size).astype(np.float32) - 0.5 w = np.random.rand( diff --git a/caffe2/core/asan.h b/caffe2/core/asan.h new file mode 100644 index 000000000000..c4526df3e7d0 --- /dev/null +++ b/caffe2/core/asan.h @@ -0,0 +1,25 @@ +#pragma once + +// Detect address sanitizer as some stuff doesn't work with it + +#undef CAFFE2_ASAN_ENABLED + +// for clang +#if defined(__has_feature) +#if ((__has_feature(address_sanitizer))) +#define CAFFE2_ASAN_ENABLED 1 +#endif +#endif + +// for gcc +#if defined(__SANITIZE_ADDRESS__) +#if __SANITIZE_ADDRESS__ +#if !defined(CAFFE2_ASAN_ENABLED) +#define CAFFE2_ASAN_ENABLED 1 +#endif +#endif +#endif + +#if !defined(CAFFE2_ASAN_ENABLED) +#define CAFFE2_ASAN_ENABLED 0 +#endif diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index c99effc43ae8..4f221385b5fd 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -56,7 +56,7 @@ class StringDeserializer : public BlobDeserializerBase { namespace { // We can't use DeviceType_Name because of a protobuf-lite constraint. -std::string tensorDeviceTypeName(const DeviceType& d) { +std::string tensorDeviceTypeName(const int32_t& d) { switch (d) { case CPU: return "TensorCPU"; @@ -84,7 +84,7 @@ std::string Blob::Serialize(const string& name) const { std::stringstream data; std::mutex mutex; BlobSerializerBase::SerializationAcceptor acceptor = - [&data, &mutex](const std::string& name, const std::string& blob) { + [&data, &mutex](const std::string&, const std::string& blob) { std::lock_guard guard(mutex); data << blob; }; diff --git a/caffe2/core/blob_serialization.h b/caffe2/core/blob_serialization.h index 6a448c903d19..d7dcef879f20 100644 --- a/caffe2/core/blob_serialization.h +++ b/caffe2/core/blob_serialization.h @@ -199,16 +199,19 @@ void TensorSerializer::SerializeWithChunkSize( std::vector> futures; #endif - for (size_t chunkBegin = 0; chunkBegin < tensor.size(); + // Serialize whole vector. If vector is empty, it's shape still needs to be + // serialized in empty proto + for (size_t chunkBegin = 0; + chunkBegin < std::max(tensor.size(), static_cast(1)); chunkBegin += chunk_size) { - auto task = [&](size_t chunkBegin) { + auto task = [&](size_t chunkStart) { BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type(kTensorBlobType); TensorProto& proto = *blob_proto.mutable_tensor(); proto.set_name(name); this->Serialize( - tensor, name, blob_proto.mutable_tensor(), chunkBegin, chunk_size); + tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size); acceptor(name, blob_proto.SerializeAsString()); }; #ifndef __ANDROID__ @@ -237,20 +240,21 @@ void TensorSerializer::Serialize( const Tensor& input, const string& name, TensorProto* proto_ptr, size_t chunkBegin, int32_t chunkSize) { CAFFE_ENFORCE( - chunkBegin < input.size(), + chunkBegin <= input.size(), "Chunk begin is out of tensor: ", chunkBegin, ' ', input.size()); + if (chunkBegin + chunkSize > input.size()) { + chunkSize = input.size() - chunkBegin; + } + CAFFE_ENFORCE( - input.raw_data(), + input.raw_data() || chunkSize == 0, "The input does not have data input yet. This is probably because you " "created a tensor of non-zero shape but never filled its data via " "mutable_data() calls. This means that it makes no sense to serialize " "the tensor content."); - if (chunkBegin + chunkSize > input.size()) { - chunkSize = input.size() - chunkBegin; - } TensorProto& proto = *proto_ptr; proto.mutable_segment()->set_begin(chunkBegin); @@ -261,6 +265,8 @@ void TensorSerializer::Serialize( } const TensorProto::DataType data_type = TypeMetaToDataType(input.meta()); proto.set_data_type(data_type); + StoreDeviceDetail(input, &proto); + // A lot of copypaste is error prone. Should we create a macro for this? switch (data_type) { case TensorProto_DataType_FLOAT: @@ -354,7 +360,6 @@ void TensorSerializer::Serialize( // Note: we intentially do not provide "default:" so if any new data types // are added, the compiler should warn the user to add the case here. } - StoreDeviceDetail(input, &proto); } template @@ -378,11 +383,6 @@ bool TensorDeserializer::Deserialize( } tensor->Resize(dims); - // Safety check for zero-sized tensors: no copy needed. - if (tensor->size() == 0) { - return true; - } - int64_t chunkBegin = 0; auto chunkEnd = tensor->size(); if (proto.has_segment()) { @@ -390,7 +390,7 @@ bool TensorDeserializer::Deserialize( chunkEnd = proto.segment().end(); } CAFFE_ENFORCE( - 0 <= chunkBegin && chunkBegin < chunkEnd && chunkEnd <= tensor->size(), + 0 <= chunkBegin && chunkBegin <= chunkEnd && chunkEnd <= tensor->size(), "Invalid chunk ", chunkBegin, ' ', diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc index 18fae4f22f55..af36f930f726 100644 --- a/caffe2/core/blob_test.cc +++ b/caffe2/core/blob_test.cc @@ -408,7 +408,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) { TEST(TensorTest, TensorSerialization_##TypeParam) { \ Blob blob; \ TensorCPU* tensor = blob.GetMutable(); \ - tensor->Resize(2, 3); \ + tensor->Resize(2, 3); \ for (int i = 0; i < 6; ++i) { \ tensor->mutable_data()[i] = static_cast(i); \ } \ @@ -437,6 +437,31 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) { EXPECT_EQ( \ tensor->data()[i], new_tensor.data()[i]); \ } \ + } \ + \ + TEST(EmptyTensorTest, TensorSerialization_##TypeParam) { \ + Blob blob; \ + TensorCPU* tensor = blob.GetMutable(); \ + tensor->Resize(0, 3); \ + tensor->mutable_data(); \ + string serialized = blob.Serialize("test"); \ + BlobProto proto; \ + CHECK(proto.ParseFromString(serialized)); \ + EXPECT_EQ(proto.name(), "test"); \ + EXPECT_EQ(proto.type(), "Tensor"); \ + EXPECT_TRUE(proto.has_tensor()); \ + const TensorProto& tensor_proto = proto.tensor(); \ + EXPECT_EQ( \ + tensor_proto.data_type(), \ + TypeMetaToDataType(TypeMeta::Make())); \ + EXPECT_EQ(tensor_proto.field_name##_size(), 0); \ + Blob new_blob; \ + EXPECT_TRUE(new_blob.Deserialize(serialized)); \ + EXPECT_TRUE(new_blob.IsType()); \ + const TensorCPU& new_tensor = blob.Get(); \ + EXPECT_EQ(new_tensor.ndim(), 2); \ + EXPECT_EQ(new_tensor.dim(0), 0); \ + EXPECT_EQ(new_tensor.dim(1), 3); \ } TEST_SERIALIZATION_WITH_TYPE(bool, int32_data) diff --git a/caffe2/core/common.h b/caffe2/core/common.h index bffb82c04425..d1651ee9b613 100644 --- a/caffe2/core/common.h +++ b/caffe2/core/common.h @@ -9,6 +9,10 @@ #include #include +#ifdef __APPLE__ +#include +#endif + namespace caffe2 { // Note(Yangqing): NVCC does not play well with unordered_map on some platforms, @@ -44,6 +48,20 @@ private: \ classname& operator=(const classname&) = delete #endif +// Define enabled when building for iOS or Android devices +#if !defined(CAFFE2_MOBILE) +#if defined(__ANDROID__) +#define CAFFE2_ANDROID 1 +#define CAFFE2_MOBILE 1 +#elif (defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define CAFFE2_IOS 1 +#define CAFFE2_MOBILE 1 +#else +#define CAFFE2_MOBILE 0 +#endif // ANDROID / IOS +#endif // CAFFE2_MOBILE + // make_unique is a C++14 feature. If we don't have 14, we will emulate // its behavior. This is copied from folly/Memory.h #if __cplusplus >= 201402L || \ diff --git a/caffe2/core/common_gpu.cc b/caffe2/core/common_gpu.cc index 74137bd66bdb..4272b0eb56ac 100644 --- a/caffe2/core/common_gpu.cc +++ b/caffe2/core/common_gpu.cc @@ -1,6 +1,7 @@ #include "caffe2/core/common_gpu.h" #include +#include #include #include "caffe2/core/init.h" @@ -9,6 +10,14 @@ namespace caffe2 { int NumCudaDevices() { + if (getenv("CAFFE2_DEBUG_CUDA_INIT_ORDER")) { + static bool first = true; + if (first) { + first = false; + std::cerr << "DEBUG: caffe2::NumCudaDevices() invoked for the first time" + << std::endl; + } + } static int count = -1; if (count < 0) { auto err = cudaGetDeviceCount(&count); @@ -28,10 +37,18 @@ int NumCudaDevices() { "have a cuda gpu."; count = 0; break; + case cudaErrorUnknown: + LOG(ERROR) << "Found an unknown error - this may be due to an " + "incorrectly set up environment, e.g. changing env " + "variable CUDA_VISIBLE_DEVICES after program start. " + "I will set the available devices to be zero."; + count = 0; + break; default: LOG(FATAL) << "Unexpected error from cudaGetDeviceCount(). Did you run " "some cuda functions before calling NumCudaDevices() " - "that might have already set an error?"; + "that might have already set an error? Error: " + << err; } } return count; @@ -193,60 +210,4 @@ const char* curandGetErrorString(curandStatus_t error) { // To suppress compiler warning. return "Unrecognized curand error string"; } - -bool Caffe2InitializeCuda(int*, char***) { - static bool g_initialization_function_called = false; - if (g_initialization_function_called == true) { - VLOG(1) << "Initialization already called. Ignoring duplicated calls."; - return true; - } - g_initialization_function_called = true; - // If the current run does not have any cuda devices, do nothing. - if (!HasCudaGPU()) { - VLOG(1) << "No cuda gpu present. Skipping."; - return true; - } - // Check if the number of GPUs matches the expected compile-time max number - // of GPUs. - CHECK_LE(NumCudaDevices(), CAFFE2_COMPILE_TIME_MAX_GPUS) - << "Number of CUDA devices on the machine is larger than the compiled " - "max number of gpus expected (" - << CAFFE2_COMPILE_TIME_MAX_GPUS - << "). Increase that and recompile the caffe binary."; - // Save the current device so we can restore it after moving across - // different devices. - int init_device; - CUDA_CHECK(cudaGetDevice(&init_device)); - - for (int i = 0; i < NumCudaDevices(); ++i) { - auto err = cudaSetDevice(i); - if (err != cudaSuccess) { - LOG(WARNING) - << "Cannot use device " << i - << "due to the following error: " << cudaGetErrorString(err); - continue; - } - // Enable peer access. - for (int j = 0; j < NumCudaDevices(); ++j) { - if (i == j) continue; - int can_access; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, i, j)); - if (can_access) { - VLOG(1) << "Enabling peer access from " << i << " to " << j; - // Note: just for future reference, the 0 here is not a gpu id, it is - // a reserved flag for cudaDeviceEnablePeerAccess that should always be - // zero currently. - CUDA_CHECK(cudaDeviceEnablePeerAccess(j, 0)); - } - } - } - // Restore the current device. - CUDA_CHECK(cudaSetDevice(init_device)); - return true; -} - -REGISTER_CAFFE2_INIT_FUNCTION(Caffe2InitializeCuda, - &Caffe2InitializeCuda, - "Enable cuda for caffe2."); - } // namespace caffe2 diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h index 088755ae5326..ed2ca0e83ce5 100644 --- a/caffe2/core/common_gpu.h +++ b/caffe2/core/common_gpu.h @@ -108,17 +108,6 @@ const char* cublasGetErrorString(cublasStatus_t error); */ const char* curandGetErrorString(curandStatus_t error); -/** - * Caffe2's CUDA initialization function. - * - * This is going to be run once when caffe2's GlobalInit() function is called. - * If you have an initialization function that depends on CUDA's initialization - * first, you can call this function inside your init function - this will - * ensure that CUDA is initialized before any of your custom initialization is - * carried out. This function is NOT thread safe. - */ -bool Caffe2InitializeCuda(); - // CUDA: various checks for different function calls. #define CUDA_CHECK(condition) \ do { \ diff --git a/caffe2/core/context_gpu.cu b/caffe2/core/context_gpu.cu index 90e28f8cfdf8..b7c07801c3ac 100644 --- a/caffe2/core/context_gpu.cu +++ b/caffe2/core/context_gpu.cu @@ -1,10 +1,12 @@ #include +#include #include #include #include "cub/util_allocator.cuh" #include "cnmem.h" +#include "caffe2/core/asan.h" #include "caffe2/core/context_gpu.h" #include "caffe2/core/init.h" #include "caffe2/core/logging.h" @@ -48,66 +50,76 @@ CAFFE_KNOWN_TYPE(Tensor); thread_local ThreadLocalCUDAObjects CUDAContext::cuda_objects_; +// TODO(jiayq): these variables shouldn't be currently accessed during static +// initialization. We should consider moving them to a Mayer's singleton to +// be totally safe against SIOF. + // Static global variables for setting up the memory pool. CudaMemoryPoolType g_cuda_memory_pool_type; -bool g_memory_allocation_already_called = false; // For cnmem allocator -vector g_cnmem_available_for_device(NumCudaDevices(), false); +vector g_cnmem_available_for_device; // For cub allocator unique_ptr g_cub_allocator; - CudaMemoryPoolType GetCudaMemoryPoolType() { return g_cuda_memory_pool_type; } -void* CUDAContext::New(size_t nbytes) { - g_memory_allocation_already_called = true; - void* ptr = nullptr; - switch (g_cuda_memory_pool_type) { - case CudaMemoryPoolType::NONE: - CUDA_CHECK(cudaMalloc(&ptr, nbytes)); - return ptr; - case CudaMemoryPoolType::CNMEM: - CAFFE_ENFORCE( - g_cnmem_available_for_device[GetCurrentGPUID()], - "Trying to allocate on device ", GetCurrentGPUID(), - " but cnmem pool is not set up for it."); - CNMEM_CHECK(cnmemMalloc(&ptr, nbytes, nullptr)); - return ptr; - case CudaMemoryPoolType::CUB: - CUDA_CHECK(g_cub_allocator->DeviceAllocate(&ptr, nbytes)); - return ptr; - } - return nullptr; -} +/////////////////////////////////////////////////////////////////////////////// +// A wrapper to allow us to lazily initialize all cuda environments that Caffe +// uses. This gets done the first time a caffe2::CUDAContext::New() gets called +// which is probably the decisive indication that this caffe2 run is going to +// use GPUs. We avoid cuda initialization with core/init.h functionalities so +// that we have minimal resource impact in case we will need to run multiple +// caffe2 instances on a GPU machine. +/////////////////////////////////////////////////////////////////////////////// -void CUDAContext::Delete(void* ptr) { - switch (g_cuda_memory_pool_type) { - case CudaMemoryPoolType::NONE: { - // If memory pool is not set up, use simple cudaFree. - cudaError_t error = cudaFree(ptr); - // For some reason, in Python runtime we sometimes delete a data pointer - // after the cuda runtime exits - this is odd but is probably caused by - // a static workspace that pycaffe2 uses, and the destruction got - // entangled in some race condition. Anyway, since cuda runtime is exiting - // anyway, we will not need to worry about memory leak, so we basically - // ignore it. This is definitely not ideal but works for now. - if (error != cudaSuccess && error != cudaErrorCudartUnloading) { - LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " - << cudaGetErrorString(error); - } - break; } - case CudaMemoryPoolType::CNMEM: - CNMEM_CHECK(cnmemFree(ptr, nullptr)); - break; - case CudaMemoryPoolType::CUB: - CUDA_CHECK(g_cub_allocator->DeviceFree(ptr)); - break; +static void Caffe2InitializeCuda() { + // If the current run does not have any cuda devices, do nothing. + if (!HasCudaGPU()) { + VLOG(1) << "No cuda gpu present. Skipping."; + return; } + // Check if the number of GPUs matches the expected compile-time max number + // of GPUs. + CHECK_LE(NumCudaDevices(), CAFFE2_COMPILE_TIME_MAX_GPUS) + << "Number of CUDA devices on the machine is larger than the compiled " + "max number of gpus expected (" + << CAFFE2_COMPILE_TIME_MAX_GPUS + << "). Increase that and recompile the caffe binary."; + // Save the current device so we can restore it after moving across + // different devices. + int init_device; + CUDA_CHECK(cudaGetDevice(&init_device)); + + for (int i = 0; i < NumCudaDevices(); ++i) { + auto err = cudaSetDevice(i); + if (err != cudaSuccess) { + LOG(WARNING) + << "Cannot use device " << i + << "due to the following error: " << cudaGetErrorString(err); + continue; + } + // Enable peer access. + for (int j = 0; j < NumCudaDevices(); ++j) { + if (i == j) continue; + int can_access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, i, j)); + if (can_access) { + VLOG(1) << "Enabling peer access from " << i << " to " << j; + // Note: just for future reference, the 0 here is not a gpu id, it is + // a reserved flag for cudaDeviceEnablePeerAccess that should always be + // zero currently. + CUDA_CHECK(cudaDeviceEnablePeerAccess(j, 0)); + } + } + } + // Restore the current device. + CUDA_CHECK(cudaSetDevice(init_device)); } static void SetUpCNMEM() { + g_cnmem_available_for_device.assign(NumCudaDevices(), false); VLOG(1) << "Setting up cnmem memory pool."; vector device_ids; // If the cnmem gpus are not set, set up all gpus. @@ -184,42 +196,28 @@ static void SetUpCub() { VLOG(1) << "Done setting up cub memory pool."; } -// Global initializtion function to set up the cuda memory pool during -// construction time. -bool Caffe2SetCUDAMemoryPool(int*, char***) { - if (!HasCudaGPU()) { - VLOG(1) << "No GPU present. I won't set up cuda memory pool"; - return true; - } - if (g_memory_allocation_already_called) { - LOG(ERROR) << "Caffe2SetCUDAMemoryPool should always be called before " - "any CUDAContext::New() calls are made."; - return false; - } +static void Caffe2SetCUDAMemoryPool() { if (FLAGS_caffe2_cuda_memory_pool == "" || FLAGS_caffe2_cuda_memory_pool == "none") { g_cuda_memory_pool_type = CudaMemoryPoolType::NONE; - return true; } else if (FLAGS_caffe2_cuda_memory_pool == "cnmem") { // sets up cnmem. g_cuda_memory_pool_type = CudaMemoryPoolType::CNMEM; SetUpCNMEM(); - return true; } else if (FLAGS_caffe2_cuda_memory_pool == "cub") { // Sets up cub. g_cuda_memory_pool_type = CudaMemoryPoolType::CUB; SetUpCub(); - return true; + } else { + CAFFE_THROW("Unrecognized cuda memory pool type: ", + FLAGS_caffe2_cuda_memory_pool); } - LOG(ERROR) << "Unrecognized cuda memory pool type: " - << FLAGS_caffe2_cuda_memory_pool; - return false; } // An initialization function that sets the CPU side to use pinned cpu // allocator. -bool Caffe2UsePinnedCPUAllocator(int*, char***) { -#ifdef __SANITIZE_ADDRESS__ +void Caffe2UsePinnedCPUAllocator() { +#if CAFFE2_ASAN_ENABLED // Note(jiayq): for more details, see // https://github.com/google/sanitizers/issues/629 LOG(WARNING) << "There are known issues between address sanitizer and " @@ -227,22 +225,99 @@ bool Caffe2UsePinnedCPUAllocator(int*, char***) { "memory allocation in asan mode. If you are expecting any " "behavior that depends on asan, be advised that it is not " "turned on."; - return true; #else if (!HasCudaGPU()) { VLOG(1) << "No GPU present. I won't use pinned allocator then."; - return true; } VLOG(1) << "Caffe2 gpu: setting CPUAllocator to PinnedCPUAllocator."; SetCPUAllocator(new PinnedCPUAllocator()); - return true; #endif } -REGISTER_CAFFE2_INIT_FUNCTION(Caffe2SetCUDAMemoryPool, - &Caffe2SetCUDAMemoryPool, - "Sets up the cuda memory pool."); -REGISTER_CAFFE2_INIT_FUNCTION(Caffe2UsePinnedCPUAllocator, - &Caffe2UsePinnedCPUAllocator, - "Make the CPU side use pinned memory."); +// Caffe2CudaInitializerHelper is a minimal struct whose sole purpose is to +// detect the first hint that this Caffe2 run is going to use GPU: either +// CUDAContext is initialized or CUDAContext::New is called. It then runs +// all the related cuda initialization functions. +namespace { +struct Caffe2CudaInitializerHelper { + Caffe2CudaInitializerHelper() { + // We cannot use bool because nvcc changes bool to __nv_bool which does + // not have a std::atomic instantiation. + static std::atomic first_call(1); + if (first_call.fetch_and((char)0)) { + Caffe2InitializeCuda(); + Caffe2SetCUDAMemoryPool(); + Caffe2UsePinnedCPUAllocator(); + } + } +}; +} // namespace + +CUDAContext::CUDAContext(const int gpu_id) + : gpu_id_(gpu_id == -1 ? GetDefaultGPUID() : gpu_id) + , random_seed_(math::randomNumberSeed()) { + static Caffe2CudaInitializerHelper g_cuda_initializer_; +} + +CUDAContext::CUDAContext(const DeviceOption& option) + : gpu_id_(option.has_cuda_gpu_id() ? + option.cuda_gpu_id() : GetDefaultGPUID()), + random_seed_(option.has_random_seed() ? + option.random_seed() : math::randomNumberSeed()) { + static Caffe2CudaInitializerHelper g_cuda_initializer_; + DCHECK_EQ(option.device_type(), CUDA); +} + + +void* CUDAContext::New(size_t nbytes) { + // A one-time caffe2 cuda initializer. + static Caffe2CudaInitializerHelper g_cuda_initializer_; + void* ptr = nullptr; + switch (g_cuda_memory_pool_type) { + case CudaMemoryPoolType::NONE: + CUDA_CHECK(cudaMalloc(&ptr, nbytes)); + return ptr; + case CudaMemoryPoolType::CNMEM: { + auto gpuId = GetCurrentGPUID(); + CAFFE_ENFORCE( + gpuId < g_cnmem_available_for_device.size() && + g_cnmem_available_for_device[gpuId], + "Trying to allocate on device ", + gpuId, + " but cnmem pool is not set up for it."); + CNMEM_CHECK(cnmemMalloc(&ptr, nbytes, nullptr)); + return ptr; + } + case CudaMemoryPoolType::CUB: + CUDA_CHECK(g_cub_allocator->DeviceAllocate(&ptr, nbytes)); + return ptr; + } + return nullptr; +} + +void CUDAContext::Delete(void* ptr) { + switch (g_cuda_memory_pool_type) { + case CudaMemoryPoolType::NONE: { + // If memory pool is not set up, use simple cudaFree. + cudaError_t error = cudaFree(ptr); + // For some reason, in Python runtime we sometimes delete a data pointer + // after the cuda runtime exits - this is odd but is probably caused by + // a static workspace that pycaffe2 uses, and the destruction got + // entangled in some race condition. Anyway, since cuda runtime is exiting + // anyway, we will not need to worry about memory leak, so we basically + // ignore it. This is definitely not ideal but works for now. + if (error != cudaSuccess && error != cudaErrorCudartUnloading) { + LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " + << cudaGetErrorString(error); + } + break; } + case CudaMemoryPoolType::CNMEM: + CNMEM_CHECK(cnmemFree(ptr, nullptr)); + break; + case CudaMemoryPoolType::CUB: + CUDA_CHECK(g_cub_allocator->DeviceFree(ptr)); + break; + } +} + } // namespace caffe2 diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h index 11f2e4bd76c3..4e99244e7ded 100644 --- a/caffe2/core/context_gpu.h +++ b/caffe2/core/context_gpu.h @@ -44,7 +44,20 @@ struct PinnedCPUAllocator final : CPUAllocator { return data; } void Delete(void* data) override { - CUDA_CHECK(cudaFreeHost(data)); + // Caffe2 uses a lazy way to figure out if one is actually going to use GPUs + // or not. If a CUDAContext::New() call is made, inside the CUDAContext + // function we will switch the cpu side allocator to a PinnedCPUAllocator. + // But, if one calls CPUContext::New() before any cuda allocations, + // PinnedCPUAllocator can still delete the corresponding memory. + cudaError_t err = cudaFreeHost(data); + if (err == cudaErrorInvalidValue) { + free(data); + // Calling cudaGetLastError will reset the cuda error. + cudaGetLastError(); + } else { + // For all other errors, still do a cuda check. + CUDA_CHECK(err); + } } }; @@ -89,18 +102,8 @@ class ThreadLocalCUDAObjects { class CUDAContext final { public: // The default cuda context constructor. - explicit CUDAContext(const int gpu_id = -1) - : gpu_id_(gpu_id == -1 ? GetDefaultGPUID() : gpu_id) - , random_seed_(math::randomNumberSeed()) { - } - - explicit CUDAContext(const DeviceOption& option) - : gpu_id_(option.has_cuda_gpu_id() ? - option.cuda_gpu_id() : GetDefaultGPUID()), - random_seed_(option.has_random_seed() ? - option.random_seed() : math::randomNumberSeed()) { - DCHECK_EQ(option.device_type(), CUDA); - } + explicit CUDAContext(const int gpu_id = -1); + explicit CUDAContext(const DeviceOption& option); ~CUDAContext() { if (curand_generator_) { diff --git a/caffe2/core/db.h b/caffe2/core/db.h index 9d92c890a218..c4130170f8b5 100644 --- a/caffe2/core/db.h +++ b/caffe2/core/db.h @@ -238,9 +238,7 @@ class DBReader { private: void MoveToBeginning() const { - if (cursor_->SupportsSeek()) { - cursor_->SeekToFirst(); - } + cursor_->SeekToFirst(); for (auto s = 0; s < shard_id_; s++) { cursor_->Next(); CAFFE_ENFORCE( diff --git a/caffe2/core/logging_test.cc b/caffe2/core/logging_test.cc index cce709e69680..9d4a36db2ab8 100644 --- a/caffe2/core/logging_test.cc +++ b/caffe2/core/logging_test.cc @@ -64,11 +64,13 @@ TEST(LoggingTest, EnforceShowcase) { WRAP_AND_PRINT(CAFFE_ENFORCE_THAT(Equals(one * two + three, three * two))); } +#if GTEST_HAS_DEATH_TEST TEST(LoggingDeathTest, TestEnforceUsingFatal) { bool kTrue = true; std::swap(FLAGS_caffe2_use_fatal_for_enforce, kTrue); EXPECT_DEATH(CAFFE_ENFORCE(false, "This goes fatal."), ""); std::swap(FLAGS_caffe2_use_fatal_for_enforce, kTrue); } +#endif } // namespace caffe2 diff --git a/caffe2/core/net.cc b/caffe2/core/net.cc index fabfc45947f9..9d9f639b01c8 100644 --- a/caffe2/core/net.cc +++ b/caffe2/core/net.cc @@ -181,15 +181,19 @@ DAGNetBase::ExecutionChains computeChains( CAFFE_DEFINE_REGISTRY(NetRegistry, NetBase, const NetDef&, Workspace*); NetBase::NetBase(const NetDef& def, Workspace* /* unused */) - : external_input_(def.external_input().begin(), - def.external_input().end()), - external_output_(def.external_output().begin(), - def.external_output().end()) { + : external_input_(def.external_input().begin(), def.external_input().end()), + external_output_( + def.external_output().begin(), + def.external_output().end()), + name_(def.name()) { // Go through the operators and make sure that blobs are correctly made. std::set known_blobs( external_input_.begin(), external_input_.end()); std::set remaining_output( external_output_.begin(), external_output_.end()); + for (const auto& blob : known_blobs) { + remaining_output.erase(blob); + } for (const OperatorDef& op : def.op()) { for (const string& in : op.input()) { if (!known_blobs.count(in)) { @@ -249,22 +253,14 @@ SimpleNet::SimpleNet(const NetDef& net_def, Workspace* ws) OperatorDef temp_def(operator_def); temp_def.mutable_device_option()->CopyFrom(net_def.device_option()); operators_.emplace_back(CreateOperator(temp_def, ws)); - CAFFE_ENFORCE( - operators_.back() != nullptr, - "Cannot create operator for def: ", - ProtoDebugString(temp_def)); } else { operators_.emplace_back(CreateOperator(operator_def, ws)); - CAFFE_ENFORCE( - operators_.back() != nullptr, - "Cannot create operator for def: ", - ProtoDebugString(operator_def)); } } } bool SimpleNet::Run() { - VLOG(1) << "Running net."; + VLOG(1) << "Running net " << name_; for (auto& op : operators_) { VLOG(1) << "Running operator " << op->def().name() << "(" << op->def().type() << ")."; @@ -278,7 +274,7 @@ bool SimpleNet::Run() { } bool SimpleNet::RunAsync() { - VLOG(1) << "Running net."; + VLOG(1) << "Running net " << name_; for (auto& op : operators_) { VLOG(1) << "Running operator " << op->def().name() << "(" << op->def().type() << ")."; @@ -385,16 +381,8 @@ DAGNetBase::DAGNetBase(const NetDef& net_def, Workspace* ws) OperatorDef temp_def(op_def); temp_def.mutable_device_option()->CopyFrom(net_def.device_option()); operator_nodes_[idx].operator_ = CreateOperator(temp_def, ws); - CAFFE_ENFORCE( - operator_nodes_[idx].operator_ != nullptr, - "Cannot create operator for def: ", - ProtoDebugString(temp_def)); } else { operator_nodes_[idx].operator_ = CreateOperator(op_def, ws); - CAFFE_ENFORCE( - operator_nodes_[idx].operator_ != nullptr, - "Cannot create operator for def: ", - ProtoDebugString(op_def)); } // Check the inputs, and set up parents if necessary. This addressese the // read after write case. diff --git a/caffe2/core/net.h b/caffe2/core/net.h index e292375ceeb1..6b05c8a312d7 100644 --- a/caffe2/core/net.h +++ b/caffe2/core/net.h @@ -63,6 +63,7 @@ class NetBase { protected: vector external_input_; vector external_output_; + string name_; DISABLE_COPY_AND_ASSIGN(NetBase); }; @@ -112,7 +113,7 @@ class DAGNetBase : public NetBase { // It checks out one ready-to-run operator from the job queue, runs it, // notifies all its children, and for any children that is ready, enqueues // it to the job queue. - virtual void WorkerFunction(); + void WorkerFunction(); vector TEST_Benchmark( const int warmup_runs, const int main_runs, diff --git a/caffe2/core/net_test.cc b/caffe2/core/net_test.cc index a10531da65db..20834299a1f0 100644 --- a/caffe2/core/net_test.cc +++ b/caffe2/core/net_test.cc @@ -153,7 +153,7 @@ TEST(NetTest, ChainingForDifferentDevices) { output: "out" type: "NetTestDummy" device_option { - device_type: CUDA + device_type: 1 } } op { @@ -161,7 +161,7 @@ TEST(NetTest, ChainingForDifferentDevices) { output: "out2" type: "NetTestDummy" device_option { - device_type: CUDA + device_type: 1 } } op { @@ -169,7 +169,7 @@ TEST(NetTest, ChainingForDifferentDevices) { output: "out3" type: "NetTestDummy" device_option { - device_type: CUDA + device_type: 1 cuda_gpu_id: 1 } } diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc index 387d87a88b8f..c29e41d86e7b 100644 --- a/caffe2/core/operator.cc +++ b/caffe2/core/operator.cc @@ -33,23 +33,20 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws) namespace { unique_ptr TryCreateOperator( const string& key, const OperatorDef& operator_def, Workspace* ws) { + auto type = operator_def.device_option().device_type(); + CAFFE_ENFORCE( + gDeviceTypeRegistry()->count(type), + "Device type ", + type, + " not registered."); + OperatorRegistry* registry = gDeviceTypeRegistry()->at(type); + VLOG(1) << "Creating operator with device type " << type; try { - switch (operator_def.device_option().device_type()) { - case CPU: - VLOG(1) << "Creating CPU operator " << key; - return CPUOperatorRegistry()->Create(key, operator_def, ws); - case CUDA: - VLOG(1) << "Creating CUDA operator " << key; - return CUDAOperatorRegistry()->Create(key, operator_def, ws); - default: - LOG(FATAL) << "Unknown device type: " - << operator_def.device_option().device_type(); - return nullptr; - } + return registry->Create(key, operator_def, ws); } catch (const UnsupportedOperatorFeature& err) { VLOG(1) << "Operator " << operator_def.type() - << " with engine does not support the requested feature. Msg: " - << err.what() << ". Proto is: " << ProtoDebugString(operator_def); + << " does not support the requested feature. Msg: " << err.what() + << ". Proto is: " << ProtoDebugString(operator_def); return nullptr; } } @@ -94,23 +91,36 @@ unique_ptr CreateOperator( // Lastly, if the engine does not work here, try using the default engine. auto op = TryCreateOperator(operator_def.type(), operator_def, ws); - if (!op) { - LOG(ERROR) << "Cannot create op from def: " - << ProtoDebugString(operator_def); - } + CAFFE_ENFORCE( + op, + "Cannot create operator of type '", + operator_def.type(), + "'. Verify that implementation for the corresponding device exist. It " + "might also happen if the binary is not linked with the operator " + "implementation code. If Python frontend is used it might happen if " + "dyndep.InitOpsLibrary call is missing. Operator def: ", + ProtoDebugString(operator_def)); return op; } +std::map* gDeviceTypeRegistry() { + static std::map g_device_type_registry; + return &g_device_type_registry; +} + CAFFE_DEFINE_REGISTRY( CPUOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*); +CAFFE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry); + CAFFE_DEFINE_REGISTRY( CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*); +CAFFE_REGISTER_DEVICE_TYPE(DeviceType::CUDA, CUDAOperatorRegistry); CAFFE_DEFINE_REGISTRY( GradientRegistry, diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 0f5fcace6354..898d3c2a16c9 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -26,22 +26,22 @@ class OperatorBase { virtual ~OperatorBase() {} // Parameter getters. You can use these to get the arguments that you want. - inline bool HasArgument(const string& name) { + inline bool HasArgument(const string& name) const { return arg_helper_.HasArgument(name); } // Functions that deal with arguments. Basically, this allows us to map an // argument name to a specific type of argument that we are trying to access. template - inline T GetSingleArgument(const string& name, const T& default_value) { + inline T GetSingleArgument(const string& name, const T& default_value) const { return arg_helper_.GetSingleArgument(name, default_value); } template - inline bool HasSingleArgumentOfType(const string& name) { + inline bool HasSingleArgumentOfType(const string& name) const { return arg_helper_.HasSingleArgumentOfType(name); } template - inline vector GetRepeatedArgument(const string& name) { + inline vector GetRepeatedArgument(const string& name) const { return arg_helper_.GetRepeatedArgument(name); } @@ -298,6 +298,36 @@ struct DispatchHelper, ExtraArgs...> { } }; +// The device type registry. This works in two phases: +// (1) gDeviceTypeRegistry() maps the device types values to the actual operator +// registry function. +// (2) Then, one can call the operator registry function to further create the +// operators. +typedef Registry + OperatorRegistry; +typedef Registry* ( + *RegistryFunction)(); +std::map* gDeviceTypeRegistry(); + +struct DeviceTypeRegisterer { + explicit DeviceTypeRegisterer(int32_t type, RegistryFunction func) { + if (gDeviceTypeRegistry()->count(type)) { + std::cerr << "Device type " << type + << "registered twice. This should not happen. Did you have " + "duplicated numbers assigned to different devices?"; + std::exit(1); + } + // Calling the registry function to get the actual registry pointer. + gDeviceTypeRegistry()->emplace(type, func()); + } +}; + +#define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \ + namespace { \ + static DeviceTypeRegisterer CAFFE_ANONYMOUS_VARIABLE( \ + DeviceType)(type, ®istry_function); \ + } + // The operator registry. Since we are not expecting a great number of devices, // we will simply have an if-then type command and allocate the actual // generation to device-specific registerers. @@ -365,6 +395,7 @@ class UnsupportedOperatorFeature : public std::exception { } // Creates an operator with the given operator definition. +// Throws on error and never returns nullptr unique_ptr CreateOperator( const OperatorDef& operator_def, Workspace* ws); diff --git a/caffe2/core/operator_test.cc b/caffe2/core/operator_test.cc index bd875d6c4852..78e64683e950 100644 --- a/caffe2/core/operator_test.cc +++ b/caffe2/core/operator_test.cc @@ -61,6 +61,10 @@ REGISTER_CPU_OPERATOR_WITH_ENGINE(JustTest, BAR, JustTestAndDoesConstruct); REGISTER_CUDA_OPERATOR(JustTest, JustTest); REGISTER_CPU_OPERATOR(ThrowException, ThrowException); +TEST(OperatorTest, DeviceTypeRegistryWorks) { + EXPECT_EQ(gDeviceTypeRegistry()->count(DeviceType::CPU), 1); +} + TEST(OperatorTest, RegistryWorks) { OperatorDef op_def; Workspace ws; @@ -132,22 +136,9 @@ TEST(OperatorTest, TestParameterAccess) { op_def.set_type("JustTest"); op_def.add_input("input"); op_def.add_output("output"); - { - Argument* arg = op_def.add_arg(); - arg->set_name("arg0"); - arg->set_f(0.1); - } - { - Argument* arg = op_def.add_arg(); - arg->set_name("arg1"); - arg->add_ints(1); - arg->add_ints(2); - } - { - Argument* arg = op_def.add_arg(); - arg->set_name("arg2"); - arg->set_s("argstring"); - } + AddArgument("arg0", 0.1, &op_def); + AddArgument>("arg1", vector{1, 2}, &op_def); + AddArgument("arg2", "argstring", &op_def); EXPECT_NE(ws.CreateBlob("input"), nullptr); OperatorBase op(op_def, &ws); EXPECT_FLOAT_EQ(op.GetSingleArgument("arg0", 0.0), 0.1); @@ -165,17 +156,14 @@ TEST(OperatorTest, CannotAccessParameterWithWrongType) { op_def.set_type("JustTest"); op_def.add_input("input"); op_def.add_output("output"); - { - Argument* arg = op_def.add_arg(); - arg->set_name("arg0"); - arg->set_f(0.1); - } + AddArgument("arg0", 0.1, &op_def); EXPECT_NE(ws.CreateBlob("input"), nullptr); OperatorBase op(op_def, &ws); EXPECT_FLOAT_EQ(op.GetSingleArgument("arg0", 0.0), 0.1); ASSERT_THROW(op.GetSingleArgument("arg0", 0), EnforceNotMet); } +#if GTEST_HAS_DEATH_TEST TEST(OperatorDeathTest, DISABLED_CannotAccessRepeatedParameterWithWrongType) { OperatorDef op_def; Workspace ws; @@ -183,11 +171,7 @@ TEST(OperatorDeathTest, DISABLED_CannotAccessRepeatedParameterWithWrongType) { op_def.set_type("JustTest"); op_def.add_input("input"); op_def.add_output("output"); - { - Argument* arg = op_def.add_arg(); - arg->set_name("arg0"); - arg->add_floats(0.1); - } + AddArgument>("arg0", vector{0.1}, &op_def); EXPECT_NE(ws.CreateBlob("input"), nullptr); OperatorBase op(op_def, &ws); auto args = op.GetRepeatedArgument("arg0"); @@ -196,6 +180,7 @@ TEST(OperatorDeathTest, DISABLED_CannotAccessRepeatedParameterWithWrongType) { EXPECT_DEATH(op.GetRepeatedArgument("arg0"), "Argument does not have the right field: expected ints"); } +#endif TEST(OperatorTest, TestDefaultValue) { OperatorDef op_def; diff --git a/caffe2/core/typeid.cc b/caffe2/core/typeid.cc index c7163a43abd3..cf5ed498536c 100644 --- a/caffe2/core/typeid.cc +++ b/caffe2/core/typeid.cc @@ -24,6 +24,14 @@ string Demangle(const char* name) { return name; } +string GetExceptionString(const std::exception& e) { +#ifdef __GXX_RTTI + return Demangle(typeid(e).name()) + ": " + e.what(); +#else + return string("Exception (no RTTI available): ") + e.what(); +#endif // __GXX_RTTI +} + namespace { // This single registerer exists solely for us to be able to name a TypeMeta // for unintializied blob. You should not use this struct yourself - it is diff --git a/caffe2/core/typeid.h b/caffe2/core/typeid.h index 4d68c7a7f373..f5ee84a4b7a6 100644 --- a/caffe2/core/typeid.h +++ b/caffe2/core/typeid.h @@ -27,6 +27,10 @@ std::set& gRegisteredTypeNames(); // A utility function to demangle a function name. string Demangle(const char* name); +// A utility function to return an exception string by prepending its exception +// type before its what() content. +string GetExceptionString(const std::exception& e); + template struct TypeNameRegisterer { explicit TypeNameRegisterer(CaffeTypeId id) { @@ -166,7 +170,7 @@ class TypeMeta { * is generated during run-time. Do NOT serialize the id for storage. */ template - static CaffeTypeId Id(); + [[gnu::visibility("default")]] static CaffeTypeId Id(); /** * Returns the item size of the type. This is equivalent to sizeof(T). @@ -184,7 +188,7 @@ class TypeMeta { template static const char* Name() { #ifdef __GXX_RTTI - static string name = Demangle(typeid(T).name()); + static const string name = Demangle(typeid(T).name()); return name.c_str(); #else // __GXX_RTTI return "(RTTI disabled, cannot show name)"; diff --git a/caffe2/core/workspace.cc b/caffe2/core/workspace.cc index 9cd006d0da3d..f9eefcc7975b 100644 --- a/caffe2/core/workspace.cc +++ b/caffe2/core/workspace.cc @@ -10,6 +10,12 @@ #include "caffe2/core/timer.h" #include "caffe2/proto/caffe2.pb.h" +CAFFE2_DEFINE_bool( + caffe2_handle_executor_threads_exceptions, + false, + "If used we will handle exceptions in executor threads. " + "This avoids SIGABRT but may cause process to deadlock"); + namespace caffe2 { namespace { @@ -36,19 +42,33 @@ std::function getContinuationTest( "Must not specify num_iter if should_stop_blob is set"); } - if (!step.has_should_stop_blob()) { + if (!step.has_should_stop_blob()) { // control by iteration + CAFFE_ENFORCE(!step.has_only_once(), "not supported"); int64_t iterations = step.has_num_iter() ? step.num_iter() : 1; VLOG(1) << "Will execute step " << step.name() << " for " << iterations << " iterations."; return [=](int64_t i) { return i < iterations; }; - } else { - VLOG(1) << "Will execute step " << step.name() << " until stopped by blob " - << step.should_stop_blob(); - return [](int64_t i) { return true; }; + } else { // control by signal blob + bool onlyOnce = step.has_only_once() && step.only_once(); + VLOG(1) << "Will execute step" << step.name() << (onlyOnce ? " once " : "") + << " until stopped by blob " << step.should_stop_blob(); + if (onlyOnce) { + return [](int64_t i) { return i == 0; }; + } else { + return [](int64_t i) { return true; }; + } } }; } // namespace +vector Workspace::LocalBlobs() const { + vector names; + for (auto& entry : blob_map_) { + names.push_back(entry.first); + } + return names; +} + vector Workspace::Blobs() const { vector names; for (auto& entry : blob_map_) { @@ -188,6 +208,20 @@ bool Workspace::RunPlan(const PlanDef& plan, return true; } +#if CAFFE2_MOBILE +ThreadPool* Workspace::GetThreadPool() { + std::lock_guard guard(thread_pool_creation_mutex_); + + if (!thread_pool_) { + auto numThreads = std::thread::hardware_concurrency(); + LOG(INFO) << "Constructing thread pool with " << numThreads << " threads"; + thread_pool_.reset(new ThreadPool(numThreads)); + } + + return thread_pool_.get(); +} +#endif // CAFFE2_MOBILE + namespace { struct Reporter { @@ -272,8 +306,8 @@ bool Workspace::ExecuteStepRecursive( if (!step.concurrent_substeps() || step.substep().size() <= 1) { VLOG(1) << "Executing step " << step.name() << " iteration " << iter; - auto substepShouldContinue = [&, externalShouldContinue](int64_t iter) { - return externalShouldContinue(iter); + auto substepShouldContinue = [&, externalShouldContinue](int64_t it) { + return externalShouldContinue(it); }; for (auto& ss : step.substep()) { @@ -288,11 +322,11 @@ bool Workspace::ExecuteStepRecursive( std::atomic next_substep{0}; std::atomic got_failure{false}; - auto substepShouldContinue = [&, externalShouldContinue](int64_t iter) { - return !got_failure && externalShouldContinue(iter); + auto substepShouldContinue = [&, externalShouldContinue](int64_t it) { + return !got_failure && externalShouldContinue(it); }; std::mutex exception_mutex; - std::exception_ptr first_exception; + string first_exception; auto worker = [&]() { while (true) { int substep_id = next_substep++; @@ -306,10 +340,18 @@ bool Workspace::ExecuteStepRecursive( } } catch (const std::exception& ex) { std::lock_guard guard(exception_mutex); - if (!first_exception) { - first_exception = std::current_exception(); + if (!first_exception.size()) { + first_exception = GetExceptionString(ex); + LOG(ERROR) << "Parallel worker exception:\n" << first_exception; } got_failure = true; + if (!FLAGS_caffe2_handle_executor_threads_exceptions) { + // In complex plans other threads might get stuck if another + // one fails. So we let exception to go out of thread which + // causes SIGABRT. In local setup one might use this flag + // in order to use Python debugger after a failure + throw; + } } } }; @@ -322,9 +364,11 @@ bool Workspace::ExecuteStepRecursive( thread.join(); } if (got_failure) { - LOG(ERROR) << "One of the workers died with an unhandled exception"; - if (first_exception != nullptr) { - std::rethrow_exception(first_exception); + LOG(ERROR) << "One of the workers failed."; + if (first_exception.size()) { + CAFFE_THROW( + "One of the workers died with an unhandled exception ", + first_exception); } return false; } diff --git a/caffe2/core/workspace.h b/caffe2/core/workspace.h index ad43296bc5fa..0aa7bd4ea1fa 100644 --- a/caffe2/core/workspace.h +++ b/caffe2/core/workspace.h @@ -1,17 +1,26 @@ #ifndef CAFFE2_CORE_WORKSPACE_H_ #define CAFFE2_CORE_WORKSPACE_H_ +#include "caffe2/core/common.h" + +#ifndef CAFFE2_MOBILE +#error "mobile build state not defined" +#endif + #include #include +#include #include #include #include "caffe2/core/blob.h" -#include "caffe2/core/common.h" #include "caffe2/core/registry.h" #include "caffe2/core/net.h" #include "caffe2/proto/caffe2.pb.h" #include "caffe2/utils/signal_handler.h" +#if CAFFE2_MOBILE +#include "caffe2/utils/threadpool/ThreadPool.h" +#endif // CAFFE2_MOBILE namespace caffe2 { @@ -73,6 +82,12 @@ class Workspace { : root_folder_(root_folder), shared_(shared) {} ~Workspace() {} + /** + * Return list of blobs owned by this Workspace, not including blobs + * shared from parent workspace. + */ + vector LocalBlobs() const; + /** * Return a list of blob names. This may be a bit slow since it will involve * creation of multiple temp variables. For best performance, simply use @@ -149,6 +164,15 @@ class Workspace { bool RunPlan(const PlanDef& plan_def, ShouldContinue should_continue = StopOnSignal{}); +#if CAFFE2_MOBILE + /* + * Returns a CPU threadpool instace for parallel execution of + * work. The threadpool is created lazily; if no operators use it, + * then no threadpool will be created. + */ + ThreadPool* GetThreadPool(); +#endif + // RunOperatorOnce and RunNetOnce runs an operator or net once. The difference // between RunNet and RunNetOnce lies in the fact that RunNet allows you to // have a persistent net object, while RunNetOnce creates a net and discards @@ -167,6 +191,10 @@ class Workspace { NetMap net_map_; string root_folder_ = "."; Workspace* shared_ = nullptr; +#if CAFFE2_MOBILE + std::unique_ptr thread_pool_; + std::mutex thread_pool_creation_mutex_; +#endif // CAFFE2_MOBILE DISABLE_COPY_AND_ASSIGN(Workspace); }; diff --git a/caffe2/mpi/mpi_gpu_test.cc b/caffe2/mpi/mpi_gpu_test.cc index 315416481f65..14ea3eb8e518 100644 --- a/caffe2/mpi/mpi_gpu_test.cc +++ b/caffe2/mpi/mpi_gpu_test.cc @@ -42,7 +42,7 @@ const char kBcastNet[] = R"NET( } } device_option { - device_type: CUDA + device_type: 1 } )NET"; @@ -106,7 +106,7 @@ const char kReduceNet[] = R"NET( } } device_option { - device_type: CUDA + device_type: 1 } )NET"; @@ -174,7 +174,7 @@ const char kMPIAllgatherNet[] = R"NET( type: "Allgather" } device_option { - device_type: CUDA + device_type: 1 } )NET"; @@ -239,7 +239,7 @@ const char kMPIAllreduceNet[] = R"NET( engine: "MPI" } device_option { - device_type: CUDA + device_type: 1 } )NET"; @@ -303,7 +303,7 @@ const char kInPlaceMPIAllreduceNet[] = R"NET( engine: "MPI" } device_option { - device_type: CUDA + device_type: 1 } )NET"; diff --git a/caffe2/mpi/mpi_python.cc b/caffe2/mpi/mpi_python.cc index b82ef32fc37f..c811f2e30602 100644 --- a/caffe2/mpi/mpi_python.cc +++ b/caffe2/mpi/mpi_python.cc @@ -30,6 +30,18 @@ PYBIND11_PLUGIN(mpi) { // with `-quiet` and skipping the finalize call. MPI_Finalize(); }); + m.def("Broadcast", [](py::bytes in) -> py::bytes { + std::string str = in; + auto comm = GlobalMPIComm(); + auto length = str.length(); + MPI_Bcast(&length, sizeof(length), MPI_CHAR, 0, comm); + auto ptr = caffe2::make_unique(length); + if (MPICommRank(comm) == 0) { + memcpy(ptr.get(), str.data(), str.length()); + } + MPI_Bcast(ptr.get(), length, MPI_CHAR, 0, comm); + return std::string(ptr.get(), length); + }); return m.ptr(); } diff --git a/caffe2/operators/concat_split_op.h b/caffe2/operators/concat_split_op.h index f4dff9105b88..e3067c614333 100644 --- a/caffe2/operators/concat_split_op.h +++ b/caffe2/operators/concat_split_op.h @@ -184,9 +184,11 @@ bool ConcatOp::RunOnDevice() { ". The input tensors can only have different dimensions " "along the axis = ", axis_, + " <", Input(0).dims(), - " vs ", - Input(j).dims()); + "> vs <", + Input(j).dims(), + ">."); } } diff --git a/caffe2/operators/conv_transpose_op.cc b/caffe2/operators/conv_transpose_op.cc index 67b2f98ea467..3dac8ae4086e 100644 --- a/caffe2/operators/conv_transpose_op.cc +++ b/caffe2/operators/conv_transpose_op.cc @@ -5,6 +5,7 @@ namespace caffe2 { namespace { REGISTER_CPU_OPERATOR(ConvTranspose, ConvTransposeOp); + REGISTER_CPU_OPERATOR( ConvTransposeGradient, ConvTransposeGradientOp); diff --git a/caffe2/operators/conv_transpose_op.h b/caffe2/operators/conv_transpose_op.h index 68b8b4132f0f..c4cf21f10c9c 100644 --- a/caffe2/operators/conv_transpose_op.h +++ b/caffe2/operators/conv_transpose_op.h @@ -10,7 +10,7 @@ namespace caffe2 { template class ConvTransposeOp final : public ConvTransposeUnpoolBase { public: - USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS; + USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context); ConvTransposeOp(const OperatorDef& operator_def, Workspace* ws) : ConvTransposeUnpoolBase(operator_def, ws) {} @@ -28,7 +28,7 @@ class ConvTransposeOp final : public ConvTransposeUnpoolBase { template class ConvTransposeGradientOp final : public ConvTransposeUnpoolBase { public: - USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS; + USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context); ConvTransposeGradientOp(const OperatorDef& operator_def, Workspace* ws) : ConvTransposeUnpoolBase(operator_def, ws) {} diff --git a/caffe2/operators/conv_transpose_op_impl.h b/caffe2/operators/conv_transpose_op_impl.h index 64a573cd1783..e45167586261 100644 --- a/caffe2/operators/conv_transpose_op_impl.h +++ b/caffe2/operators/conv_transpose_op_impl.h @@ -43,14 +43,17 @@ bool ConvTransposeOp::RunOnDeviceWithOrderNCHW() { const int input_image_size = H * W; const int output_image_size = Y->dim32(2) * Y->dim32(3); +#ifndef __ARM_NEON__ if (bias_multiplier_.size() != output_image_size) { bias_multiplier_.Resize(vector(1, output_image_size)); math::Set( - output_image_size, - static_cast(1), - bias_multiplier_.template mutable_data(), - &context_); + output_image_size, + static_cast(1), + bias_multiplier_.template mutable_data(), + &context_); } +#endif // !__ARM_NEON__ + const T* Xdata = X.template data(); T* Ydata = Y->template mutable_data(); @@ -71,6 +74,7 @@ bool ConvTransposeOp::RunOnDeviceWithOrderNCHW() { 0, col_buffer_data, &context_); + // Col2im math::Col2im( col_buffer_data, @@ -89,7 +93,9 @@ bool ConvTransposeOp::RunOnDeviceWithOrderNCHW() { stride_w_, Ydata, &context_); + // Bias term +#ifndef __ARM_NEON__ math::Gemm( CblasNoTrans, CblasNoTrans, @@ -102,6 +108,15 @@ bool ConvTransposeOp::RunOnDeviceWithOrderNCHW() { 1, Ydata, &context_); +#else + math::BiasCHW( + bias.template data(), + C, + output_image_size, + Ydata, + &context_); +#endif // !__ARM_NEON__ + Xdata += M * H * W; Ydata += Y->size() / Y->dim32(0); } diff --git a/caffe2/operators/conv_transpose_unpool_op_base.h b/caffe2/operators/conv_transpose_unpool_op_base.h index 675c1500be00..158b22128e4c 100644 --- a/caffe2/operators/conv_transpose_unpool_op_base.h +++ b/caffe2/operators/conv_transpose_unpool_op_base.h @@ -187,8 +187,8 @@ class ConvTransposeUnpoolBase : public Operator { } }; -#define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS \ - USE_OPERATOR_CONTEXT_FUNCTIONS; \ +#define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \ + USE_OPERATOR_FUNCTIONS(Context); \ using ConvTransposeUnpoolBase::pad_t_; \ using ConvTransposeUnpoolBase::pad_b_; \ using ConvTransposeUnpoolBase::pad_l_; \ diff --git a/caffe2/operators/counter_ops.cc b/caffe2/operators/counter_ops.cc index 511fa97c5a15..b6082eb371d9 100644 --- a/caffe2/operators/counter_ops.cc +++ b/caffe2/operators/counter_ops.cc @@ -1,9 +1,67 @@ #include "counter_ops.h" +#include "caffe2/core/blob_serialization.h" + namespace caffe2 { namespace { +namespace { +/** + * @brief CounterSerializer is the serializer for Counter type. + * + * CounterSerializer takes in a blob that contains a Counter, and serializes + * it into a BlobProto protocol buffer. At the moment only int64_t counters are + * supported (since it's the only once that is really used). + * + */ +class CounterSerializer : public BlobSerializerBase { + public: + CounterSerializer() {} + ~CounterSerializer() {} -// TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp + void Serialize( + const Blob& blob, + const string& name, + SerializationAcceptor acceptor) override { + CAFFE_ENFORCE(blob.IsType>>()); + + BlobProto blob_proto; + blob_proto.set_name(name); + blob_proto.set_type("std::unique_ptr>"); + TensorProto& proto = *blob_proto.mutable_tensor(); + proto.set_name(name); + proto.set_data_type(TensorProto_DataType_INT64); + proto.add_dims(1); + proto.add_int64_data( + blob.template Get>>()->retrieve()); + acceptor(name, blob_proto.SerializeAsString()); + } +}; + +/** + * @brief CounterDeserializer is the deserializer for Counters. + * + */ +class CounterDeserializer : public BlobDeserializerBase { + public: + bool Deserialize(const BlobProto& proto, Blob* blob) override { + auto tensorProto = proto.tensor(); + CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1, "Unexpected size of dims"); + CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1, "Unexpected value of dims"); + CAFFE_ENFORCE_EQ( + tensorProto.data_type(), + TensorProto_DataType_INT64, + "Only int64_t counters supported"); + CAFFE_ENFORCE_EQ( + tensorProto.int64_data_size(), 1, "Unexpected size of data"); + *blob->GetMutable>>() = + caffe2::make_unique>(tensorProto.int64_data(0)); + return true; + } +}; +} + +// TODO(jiayq): deprecate these ops & consolidate them with +// IterOp/AtomicIterOp REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp); REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp); @@ -80,5 +138,11 @@ SHOULD_NOT_DO_GRADIENT(RetrieveCount); } // namespace CAFFE_KNOWN_TYPE(std::unique_ptr>); +REGISTER_BLOB_SERIALIZER( + (TypeMeta::Id>>()), + CounterSerializer); +REGISTER_BLOB_DESERIALIZER( + std::unique_ptr>, + CounterDeserializer); } // namespace caffe2 diff --git a/caffe2/operators/cross_entropy_op.cc b/caffe2/operators/cross_entropy_op.cc index d5f84a5882b5..5f2b5d1d74d0 100644 --- a/caffe2/operators/cross_entropy_op.cc +++ b/caffe2/operators/cross_entropy_op.cc @@ -89,7 +89,7 @@ bool SigmoidCrossEntropyWithLogitsGradientOp::RunOnDevice() { auto in_idx = 0; for (int i = 0; i < outer_size; ++i) { auto g_factor = -g_ptr[i] / inner_size; - for (int i = 0; i < inner_size; ++i) { + for (int j = 0; j < inner_size; ++j) { out_ptr[in_idx] = g_factor * sigmoid_xent_backward(logits_ptr[in_idx], targets_ptr[in_idx]); ++in_idx; diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc index f5a421969406..34217856e9e6 100644 --- a/caffe2/operators/dataset_ops.cc +++ b/caffe2/operators/dataset_ops.cc @@ -2,6 +2,7 @@ #include #include #include +#include "caffe2/core/blob_serialization.h" #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" #include "caffe2/utils/string_utils.h" @@ -402,10 +403,8 @@ class SortAndShuffleOp : public Operator { bool RunOnDevice() override { auto& cursor = OperatorBase::Input>(0); CAFFE_ENFORCE(InputSize() == cursor->it.fields().size() + 1); - CAFFE_ENFORCE( - -1 <= sort_by_field_idx_ && - sort_by_field_idx_ < cursor->it.fields().size()); - + CAFFE_ENFORCE(-1 <= sort_by_field_idx_); + CAFFE_ENFORCE(cursor->it.fields().size() - sort_by_field_idx_ > 0); int size; if (sort_by_field_idx_ != -1) { size = Input(sort_by_field_idx_ + 1).dims()[0]; @@ -415,9 +414,13 @@ class SortAndShuffleOp : public Operator { CAFFE_ENFORCE( batch_size_ > 0 && shuffle_size_ > 0 && - 0 < batch_size_ * shuffle_size_ && batch_size_ * shuffle_size_ <= size); - int num_batch = size / batch_size_; + 0 < batch_size_ * shuffle_size_); + // adjust shuffle_size_ if it is too large + if (batch_size_ * shuffle_size_ > size) { + shuffle_size_ = size / batch_size_; + } + int num_batch = size / batch_size_; auto* out = Output(0); out->Resize(size); auto* out_data = out->mutable_data(); @@ -709,56 +712,52 @@ class CollectTensorOp final : public Operator { } bool RunOnDevice() override { - // TENSOR_VECTOR_IN is enforced inplace with TENSOR_VECTOR_OUT - TensorVectorPtr& tensorVector = - *OperatorBase::Output>(TENSOR_VECTOR_OUT); - - auto* position_out = Output(POSITION_OUT); - const auto& tensor = Input(TENSOR_TO_COLLECT); - int pos = -1; - if (InputSize() >= 3) { - CAFFE_ENFORCE(0 == Input(POSITION_IN).ndim()); - pos = Input(POSITION_IN).template data()[0]; + if (numVisited_ < numToCollect_) { + // append + pos = numVisited_; } else { - if (numVisited_ < numToCollect_) { - // append - pos = tensorVector->size(); - } else { + auto& gen = context_.RandGenerator(); + // uniform between [0, numVisited_] + std::uniform_int_distribution uniformDist(0, numVisited_); + pos = uniformDist(gen); + if (pos >= numToCollect_) { + // discard + pos = -1; + } + } + + for (int i = 0; i < OutputSize(); ++i) { + // TENSOR_VECTOR_IN is enforced inplace with TENSOR_VECTOR_OUT + TensorVectorPtr& tensorVector = + *OperatorBase::Output>(i); + + if (numVisited_ >= numToCollect_) { CAFFE_ENFORCE( tensorVector->size() == numToCollect_, "TensorVecotor size = ", tensorVector->size(), " is different from numToCollect = ", numToCollect_); - auto& gen = context_.RandGenerator(); - // uniform between [0, numVisited_] - std::uniform_int_distribution uniformDist(0, numVisited_); - pos = uniformDist(gen); - if (pos >= numToCollect_) { - // discard - pos = -1; - } + } + + const auto& tensor = Input(OutputSize() + i); + + if (pos < 0) { + // discard + CAFFE_ENFORCE(numVisited_ >= numToCollect_); + } else if (pos >= tensorVector->size()) { + // append + tensorVector->push_back(Tensor()); + tensorVector->back().template CopyFrom( + tensor, &context_); + } else { + // replace + tensorVector->at(pos).template CopyFrom( + tensor, &context_); } } - if (pos < 0) { - // discard - CAFFE_ENFORCE(numVisited_ >= numToCollect_); - } else if (pos >= tensorVector->size()) { - // append - tensorVector->push_back(Tensor()); - tensorVector->back().template CopyFrom( - tensor, &context_); - } else { - // replace - tensorVector->at(pos).template CopyFrom( - tensor, &context_); - } - - position_out->Resize(vector()); - position_out->template mutable_data()[0] = pos; - numVisited_++; return true; } @@ -768,8 +767,6 @@ class CollectTensorOp final : public Operator { int numToCollect_; // number of tensors visited int numVisited_; - INPUT_TAGS(TENSOR_VECTOR_IN, TENSOR_TO_COLLECT, POSITION_IN); - OUTPUT_TAGS(TENSOR_VECTOR_OUT, POSITION_OUT); }; REGISTER_CPU_OPERATOR(CreateTreeCursor, CreateTreeCursorOp); @@ -1007,28 +1004,20 @@ along the first dimension. .Output(0, "tensor", "tensor after concatenating"); OPERATOR_SCHEMA(CollectTensor) - .NumInputs(2, 3) - .NumOutputs(2) - .EnforceInplace({{0, 0}}) - .AllowInplace({{2, 1}}) + .NumInputs([](int n) { return n > 0 && n % 2 == 0; }) + .NumOutputs(1, INT_MAX) + .NumInputsOutputs([](int in, int out) { return in == out * 2; }) + .EnforceInplace([](int in, int out) { return in == out; }) .SetDoc(R"DOC( Collect tensor into tensor vector by reservoir sampling, argument num_to_collect indicates the max number of tensors that will be -collcted - )DOC") - .Arg("num_to_collect", "The max number of tensors to collect") - .Input(0, "input tensor vector", "tensor vector with collected tensors") - .Input(1, "tensor", "new tensor will be collected by reservoir sampling") - .Input(2, "input position", R"DOC( -if provided, new tensor will be collected in the way indicated by position. -e.g. if position < 0, discard the new tensor, if position == k and k < the size -of input tensor vector, replace the tensor at position k with the new tensor. - )DOC") - .Output(0, "output tensor vector", "enforce inplace with input 0") - .Output(1, "output position", R"DOC( -record the position at which the new tensor was collcted, -position < 0 means it's discarded. - )DOC"); +collcted. The first half of the inputs are tensor vectors, which are also the +outputs. The second half of the inputs are the tensors to be collected into each +vector (in the same order). The input tensors are collected in all-or-none +manner. If they are collected, they will be placed at the same index in the +output vectors. +)DOC") + .Arg("num_to_collect", "The max number of tensors to collect"); SHOULD_NOT_DO_GRADIENT(CreateTreeCursor); SHOULD_NOT_DO_GRADIENT(ResetCursor); @@ -1044,4 +1033,83 @@ SHOULD_NOT_DO_GRADIENT(CollectTensor); } // namespace CAFFE_KNOWN_TYPE(std::unique_ptr); CAFFE_KNOWN_TYPE(TensorVectorPtr); + +namespace { + +class TreeCursorSerializer : public BlobSerializerBase { + public: + TreeCursorSerializer() {} + ~TreeCursorSerializer() {} + + void Serialize( + const Blob& blob, + const string& name, + SerializationAcceptor acceptor) override { + auto& cursor = blob.template Get>(); + BlobProto blob_proto; + + // serialize offsets as a tensor + if (cursor->offsets.size() > 0) { + Blob offsets_blob; + auto* offsets = offsets_blob.template GetMutable>(); + offsets->Resize(cursor->offsets.size()); + std::copy( + cursor->offsets.begin(), + cursor->offsets.end(), + offsets->mutable_data()); + TensorSerializer ser; + ser.Serialize( + *offsets, name, blob_proto.mutable_tensor(), 0, offsets->size()); + } + blob_proto.set_name(name); + blob_proto.set_type("std::unique_ptr"); + + // serialize field names in the content + std::ostringstream os; + for (const auto& field : cursor->it.fields()) { + os << field.name << " "; + } + blob_proto.set_content(os.str()); + + acceptor(name, blob_proto.SerializeAsString()); + } +}; + +class TreeCursorDeserializer : public BlobDeserializerBase { + public: + bool Deserialize(const BlobProto& proto, Blob* blob) override { + // deserialize the offsets + TensorDeserializer deser; + Blob offset_blob; + deser.Deserialize(proto, &offset_blob); + auto& offsets = offset_blob.template Get>(); + auto* offsets_ptr = offsets.data(); + + // deserialize the field names + std::vector fieldNames; + std::istringstream is(proto.content()); + std::string field; + while (true) { + is >> field; + if (is.eof()) { + break; + } + fieldNames.push_back(field); + } + TreeIterator it(fieldNames); + + auto* base = blob->template GetMutable>(); + (*base).reset(new TreeCursor(it)); + (*base)->offsets.assign(offsets_ptr, offsets_ptr + offsets.size()); + return true; + } +}; + +REGISTER_BLOB_SERIALIZER( + (TypeMeta::Id>()), + TreeCursorSerializer); +REGISTER_BLOB_DESERIALIZER(std::unique_ptr, TreeCursorDeserializer); + +} // namespace + } // caffe2 diff --git a/caffe2/operators/distance_op.cc b/caffe2/operators/distance_op.cc index 67dc55382f64..ec82918b17e9 100644 --- a/caffe2/operators/distance_op.cc +++ b/caffe2/operators/distance_op.cc @@ -7,9 +7,9 @@ bool SquaredL2DistanceOp::RunOnDevice() { auto& X = Input(0); auto& Y = Input(1); auto* distance = Output(0); - CAFFE_ENFORCE(X.ndim() == Y.ndim()); + CAFFE_ENFORCE_EQ(X.ndim(), Y.ndim()); for (int i = 0; i < X.ndim(); ++i) { - CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i)); + CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i)); } int N = X.ndim() > 0 ? X.dim32(0) : 1; int D = X.size() / N; @@ -35,9 +35,9 @@ bool DotProductOp::RunOnDevice() { auto& X = Input(X_IN); auto& Y = Input(Y_IN); auto* result = Output(DOT_OUT); - CAFFE_ENFORCE(X.ndim() == Y.ndim()); + CAFFE_ENFORCE_EQ(X.ndim(), Y.ndim()); for (int i = 0; i < X.ndim(); ++i) { - CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i)); + CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i)); } int N = X.ndim() > 0 ? X.dim32(0) : 1; int D = X.size() / N; @@ -58,9 +58,9 @@ bool CosineSimilarityOp::RunOnDevice() { auto& X = Input(X_IN); auto& Y = Input(Y_IN); auto* result = Output(COS_OUT); - CAFFE_ENFORCE(X.ndim() == Y.ndim()); + CAFFE_ENFORCE_EQ(X.ndim(), Y.ndim()); for (int i = 0; i < X.ndim(); ++i) { - CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i)); + CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i)); } int N = X.ndim() > 0 ? X.dim32(0) : 1; int D = X.size() / N; diff --git a/caffe2/operators/elementwise_op_schema.cc b/caffe2/operators/elementwise_op_schema.cc index b45049d442db..360b72b1b9ec 100644 --- a/caffe2/operators/elementwise_op_schema.cc +++ b/caffe2/operators/elementwise_op_schema.cc @@ -86,6 +86,10 @@ class GetAddGradient : public GradientMakerBase { vector{GI(1)}); } } + // Make sure the broadcast argument is not copied over. + bool CopyArguments() const override { + return false; + } }; REGISTER_GRADIENT(Add, GetAddGradient); @@ -113,6 +117,10 @@ class GetSubGradient : public GradientMakerBase { vector{GI(1)})}; } } + // Make sure the broadcast argument is not copied over. + bool CopyArguments() const override { + return false; + } }; REGISTER_GRADIENT(Sub, GetSubGradient); @@ -133,19 +141,27 @@ class GetMulGradient : public GradientMakerBase { } else { return vector{ CreateOperatorDef( - "Mul", "", vector{GO(0), I(1)}, vector{GI(0)}), + "Mul", + "mul_with_broadcast_grad_1", + vector{GO(0), I(1)}, + vector{GI(0)}, + vector{MakeArgument("broadcast", 1)}), CreateOperatorDef( "Mul", - "", + "mul_with_broadcast_grad_2", vector{GO(0), I(0)}, vector{GI(1) + "_autogen_pre_red"}), CreateOperatorDef( "SumReduceLike", - "", + "mul_with_broadcast_grad_3", vector{GI(1) + "_autogen_pre_red", I(1)}, vector{GI(1)})}; } } + // Make sure the broadcast argument is not copied over. + bool CopyArguments() const override { + return false; + } }; REGISTER_GRADIENT(Mul, GetMulGradient); diff --git a/caffe2/operators/elu_op.cc b/caffe2/operators/elu_op.cc new file mode 100644 index 000000000000..71da93d7544c --- /dev/null +++ b/caffe2/operators/elu_op.cc @@ -0,0 +1,81 @@ +#include "caffe2/operators/elu_op.h" + +#include "caffe2/utils/math.h" + +namespace caffe2 { + +template <> +bool EluOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + Y->ResizeLike(X); + const auto* Xdata = X.template data(); + auto* Ydata = Y->template mutable_data(); + ConstEigenVectorArrayMap Xvec(Xdata, X.size()); + EigenVectorArrayMap Yvec(Ydata, Y->size()); + Yvec = (Xvec > 0).select(Xvec, alpha_ * (Xvec.exp() - 1.0f)); + return true; +} + +template <> +bool EluGradientOp::RunOnDevice() { + auto& Y = Input(0); + auto& dY = Input(1); + auto* dX = Output(0); + DCHECK_GT(Y.size(), 0); + DCHECK_EQ(dY.size(), Y.size()); + dX->ResizeLike(Y); + + const float* Ydata = Y.data(); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + ConstEigenVectorArrayMap Yvec(Ydata, Y.size()); + ConstEigenVectorArrayMap dYvec(dYdata, dY.size()); + EigenVectorArrayMap dXvec(dXdata, dX->size()); + dXvec = (Yvec > 0).select(dYvec, dYvec * (Yvec + alpha_)); + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(Elu, EluOp); +REGISTER_CPU_OPERATOR(EluGradient, EluGradientOp); + +// Input: X, output: Y +OPERATOR_SCHEMA(Elu) + .NumInputs(1) + .NumOutputs(1) + .AllowInplace({{0, 0}}) + .SetDoc(R"DOC( + +Elu takes one input data (Tensor) and produces one output data +(Tensor) where the function `f(x) = alpha * (exp(x) - 1.) for x < +0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise. + +)DOC") + .Input(0, "X", "1D input tensor") + .Output(0, "Y", "1D input tensor"); + +// Input: Y, dY, output: dX +OPERATOR_SCHEMA(EluGradient) + .NumInputs(2) + .NumOutputs(1) + .AllowInplace({{1, 0}}) + .SetDoc(R"DOC( +EluGradient takes both Y and dY and uses this to update dX according to the +chain rule and derivatives of the rectified linear function. +)DOC"); + +class GetEluGradient : public GradientMakerBase { + using GradientMakerBase::GradientMakerBase; + vector GetGradientDefs() override { + return SingleGradientDef( + def_.type() + "Gradient", + "", + vector{O(0), GO(0)}, + vector{GI(0)}); + } +}; +REGISTER_GRADIENT(Elu, GetEluGradient); + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/elu_op.h b/caffe2/operators/elu_op.h new file mode 100644 index 000000000000..0f2e7e7945a0 --- /dev/null +++ b/caffe2/operators/elu_op.h @@ -0,0 +1,37 @@ +#pragma once + +#include "caffe2/core/context.h" +#include "caffe2/core/logging.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +template +class EluOp final : public Operator { + public: + EluOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + alpha_(OperatorBase::GetSingleArgument("alpha", 1.0)) {} + USE_OPERATOR_CONTEXT_FUNCTIONS; + + bool RunOnDevice() override; + + protected: + T alpha_; +}; + +template +class EluGradientOp final : public Operator { + public: + EluGradientOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + alpha_(OperatorBase::GetSingleArgument("alpha", 1.0)) {} + USE_OPERATOR_CONTEXT_FUNCTIONS; + + bool RunOnDevice() override; + + protected: + T alpha_; +}; + +} // namespace caffe2 diff --git a/caffe2/operators/fully_connected_op.h b/caffe2/operators/fully_connected_op.h index e5112d0bd9ca..75be8922d065 100644 --- a/caffe2/operators/fully_connected_op.h +++ b/caffe2/operators/fully_connected_op.h @@ -26,8 +26,8 @@ class FullyConnectedOp final : public Operator { CAFFE_ENFORCE(b.ndim() == 1, b.ndim()); // batch size const auto canonical_axis = X.canonical_axis_index(axis_); - const int M = X.size_to_dim(canonical_axis); - const int K = X.size_from_dim(canonical_axis); + const auto M = X.size_to_dim(canonical_axis); + const auto K = X.size_from_dim(canonical_axis); const int N = W.dim32(0); auto dimErrorString = [&]() { @@ -50,8 +50,7 @@ class FullyConnectedOp final : public Operator { }; // Error checking - CAFFE_ENFORCE(M * K == X.size(), dimErrorString()); - CAFFE_ENFORCE(K * N == W.size(), dimErrorString()); + CAFFE_ENFORCE(M == X.size() / K, dimErrorString()); CAFFE_ENFORCE(K == W.size() / W.dim32(0), dimErrorString()); CAFFE_ENFORCE(N == b.dim32(0), dimErrorString()); CAFFE_ENFORCE(N == b.size(), dimErrorString()); diff --git a/caffe2/operators/fully_connected_op_test.cc b/caffe2/operators/fully_connected_op_test.cc index 4457df0c0205..1342e80172c4 100644 --- a/caffe2/operators/fully_connected_op_test.cc +++ b/caffe2/operators/fully_connected_op_test.cc @@ -1,3 +1,5 @@ +// TODO(#14383029) cblas_sgemm not yet implemented on osmeta +#if !defined(__OSMETA__) #include #include "caffe2/operators/fully_connected_op.h" @@ -47,3 +49,4 @@ TEST(FullyConnectedTest, Test) { } } // namespace caffe2 +#endif diff --git a/caffe2/operators/h_softmax_op.cc b/caffe2/operators/h_softmax_op.cc index fb65f875aa6d..09784d4de089 100644 --- a/caffe2/operators/h_softmax_op.cc +++ b/caffe2/operators/h_softmax_op.cc @@ -55,6 +55,9 @@ float HSoftmaxOp::RunForwardSingle(const float* X, int_output_offset += dim_out; + if (target < 0) { + return -1; + } //Return cross entropy loss return -log(std::max(softmax_output_data[target], kLOG_THRESHOLD())); } @@ -84,8 +87,7 @@ bool HSoftmaxOp::RunOnDevice() { math::Set(M, 0.f, Ydata, &context_); const auto* labeldata = label.data(); - std::unordered_map hierarchy = getHierarchyForLabels(M, - labeldata, hierarchy_); + auto hierarchy = getHierarchyForLabels(M, labeldata, hierarchy_all_map_); int int_output_size = getIntermediateOutputSize(labeldata, M, hierarchy); intermediate_output->Resize(int_output_size); float * int_output_data = intermediate_output->mutable_data(); @@ -217,8 +219,7 @@ bool HSoftmaxGradientOp::RunOnDevice() { int K = X.size() / M; const auto* labeldata = label.data(); - std::unordered_map hierarchy = getHierarchyForLabels(M, - labeldata, hierarchy_); + auto hierarchy = getHierarchyForLabels(M, labeldata, hierarchy_all_map_); int output_offset = getIntermediateOutputSize(labeldata, M, hierarchy); //Traverse backward to access intermediate_output generated by HSoftmaxOp @@ -240,10 +241,180 @@ bool HSoftmaxGradientOp::RunOnDevice() { return true; } +// Implementation for the CPU context. +template <> +bool HSoftmaxSearchOp::pruning( + const float* X, + int sample, + int K, + const float* W, + const float* b, + const NodeProto& src_node, + NodeProto& dst_node, + float parent_score, + float beam) { + int w_length = src_node.children_size() + src_node.word_ids_size(); + Tensor intermediate_data; + intermediate_data.Resize(2 * w_length); + float* int_output_data = intermediate_data.template mutable_data(); + int int_output_offset = 0; + int w_offset = src_node.offset(); + + RunForwardSingle( + X + K * sample, + W + w_offset * K, + b + w_offset, + -1, + int_output_data, + bias_multiplier_.template data() + sample, + w_length, + K, + int_output_offset); + + float* softmax_output_data = int_output_data + w_length; + // real probabilities + for (int i = 0; i < w_length; i++) { + softmax_output_data[i] = + -log(std::max(softmax_output_data[i], kLOG_THRESHOLD())) + parent_score; + } + for (int i = 0; i < src_node.children_size(); i++) { + if (softmax_output_data[i] < parent_score + beam) { + dst_node.add_children(); + int idx = dst_node.children_size() - 1; + CAFFE_ENFORCE( + src_node.children(i).has_offset(), + "HSM Search require the field offset in NodeProte"); + dst_node.mutable_children(idx)->set_offset(src_node.children(i).offset()); + CAFFE_ENFORCE( + src_node.children(i).has_name(), + "HSM Search require the field name in NodeProte"); + dst_node.mutable_children(idx)->set_name(src_node.children(i).name()); + dst_node.add_scores(softmax_output_data[i]); + pruning( + X, + sample, + K, + W, + b, + src_node.children(i), + *dst_node.mutable_children(idx), + softmax_output_data[i], + beam); + } + } + + for (int i = src_node.children_size(); i < w_length; i++) { + if (softmax_output_data[i] < parent_score + beam) { + dst_node.add_word_ids(src_node.word_ids(i - src_node.children_size())); + dst_node.add_scores(softmax_output_data[i]); + } + } + + return true; +} + +template <> +bool HSoftmaxSearchOp::extractNodes( + const NodeProto& node, + std::vector>& info) { + int i = 0; + + for (const auto& n : node.children()) { + info.emplace_back(std::make_pair(n.name(), node.scores(i++))); + } + for (const int n : node.word_ids()) { + info.emplace_back(std::make_pair(caffe2::to_string(n), node.scores(i++))); + } + + for (const auto& n : node.children()) { + extractNodes(n, info); + } + return true; +} + +// Implementation for the CPU context. +template <> +bool HSoftmaxSearchOp::RunOnDevice() { + auto& X = Input(0); + const auto& W = Input(1); + const auto& b = Input(2); + auto* Y_names = Output(0); + auto* Y_scores = Output(1); + // Batch size + int M = X.ndim() > 1 ? X.dim32(0) : 1; + // Input feature dimension + int K = X.size() / M; + CAFFE_ENFORCE(W.ndim() == 2, "Weight must be a matrix."); // N*K + CAFFE_ENFORCE(b.ndim() == 1, "Bias must be a vector."); // N + CAFFE_ENFORCE(K == W.size() / (W.dim32(0)), "feature dimension mismatch."); + // Sum of output dimensions of all hierarchy nodes + int N = W.dim32(0); + CAFFE_ENFORCE(N == b.dim32(0), "mismatch between Weight and Bias."); + Y_names->Resize(M, top_n_); + Y_scores->Resize(M, top_n_); + + if (bias_multiplier_.size() != M) { + bias_multiplier_.Resize(M); + math::Set( + M, + static_cast(1), + bias_multiplier_.mutable_data(), + &context_); + } + + for (int sample = 0; sample < M; ++sample) { + CAFFE_ENFORCE( + tree_.root_node().has_offset(), + "HSM Search require the field offset in NodeProte"); + CAFFE_ENFORCE( + tree_.root_node().has_name(), + "HSM Search require the field name in NodeProte"); + + NodeProto dst_node; + dst_node.set_offset(tree_.root_node().offset()); + dst_node.set_name(tree_.root_node().name()); + + pruning( + X.data(), + sample, + K, + W.data(), + b.data(), + tree_.root_node(), + dst_node, + 0, + beam_); + + std::vector> info; + extractNodes(dst_node, info); + // saving the results for each sample. + std::partial_sort( + info.begin(), + info.begin() + (top_n_ < info.size() ? top_n_ : info.size() - 1), + info.end(), + [&](std::pair a, std::pair b) { + return a.second < b.second; + }); + auto* y_name_data = Y_names->mutable_data() + sample * top_n_; + auto* y_score_data = Y_scores->mutable_data() + sample * top_n_; + for (int i = 0; i < top_n_; i++) { + if (i < info.size()) { + y_name_data[i] = info[i].first; + y_score_data[i] = info[i].second; + } else { + y_score_data[i] = 0; + } + } + } + + return true; +} + namespace { REGISTER_CPU_OPERATOR(HSoftmax, HSoftmaxOp); REGISTER_CPU_OPERATOR(HSoftmaxGradient, HSoftmaxGradientOp); +REGISTER_CPU_OPERATOR(HSoftmaxSearch, HSoftmaxSearchOp); OPERATOR_SCHEMA(HSoftmax) .NumInputs(4) @@ -294,5 +465,36 @@ class GetHSoftmaxGradient : public GradientMakerBase { } }; REGISTER_GRADIENT(HSoftmax, GetHSoftmaxGradient); + +OPERATOR_SCHEMA(HSoftmaxSearch) + .NumInputs(3) + .NumOutputs(2) + .SetDoc(R"DOC( + HSoftmaxSearch is an operator to generate the most possible paths given a + well-trained model and input vector. Greedy algorithm is used for pruning the + search tree. + )DOC") + .Arg( + "tree", + "Serialized TreeProto string containing a tree " + "including all intermidate nodes and leafs. All nodes must have names " + "for correct outputs") + .Arg( + "beam", + "beam used for pruning tree. The pruning algorithm is that " + "only children, whose score is smaller than parent's score puls beam, " + "will be propagated. ") + .Arg("topN", "Number of nodes in outputs") + .Input(0, "X", "Input data from previous layer") + .Input(1, "W", "The matrix trained from Softmax Ops") + .Input(2, "b", "The bias traiend from Softmax Ops") + .Output( + 0, + "Y_names", + "The name of selected nodes and leafs. " + "For nodes, it will be the name defined in the tree. " + "For leafs, it will be the index of the word in the tree.") + .Output(1, "Y_scores", "The corresponding scores of Y_names"); +SHOULD_NOT_DO_GRADIENT(HSoftmaxSearch); } // namespace } // namespace caffe2 diff --git a/caffe2/operators/h_softmax_op.h b/caffe2/operators/h_softmax_op.h index df02a43b8583..9f55318afc58 100644 --- a/caffe2/operators/h_softmax_op.h +++ b/caffe2/operators/h_softmax_op.h @@ -9,23 +9,71 @@ namespace caffe2 { -template -class HSoftmaxOp final : public Operator { +template +class HSoftmaxOpBase : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - HSoftmaxOp(const OperatorDef& operator_def, Workspace* ws) + HSoftmaxOpBase(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) { - hierarchy_.ParseFromString( + HierarchyProto hierarchy; + hierarchy.ParseFromString( OperatorBase::GetSingleArgument("hierarchy", "")); + for (const auto& path : hierarchy.paths()) { + hierarchy_all_map_.emplace(path.word_id(), path); + } } - bool RunOnDevice() override; - private: - HierarchyProto hierarchy_; + protected: + std::unordered_map hierarchy_all_map_; Tensor scale_; Tensor sum_multiplier_; Tensor bias_multiplier_; - DISABLE_COPY_AND_ASSIGN(HSoftmaxOp); + static constexpr T kLOG_THRESHOLD() { + return 1e-20; + } + static std::unordered_map getHierarchyForLabels( + int M, + const int* labels, + const std::unordered_map& hierarchy_all_map) { + std::unordered_map hierarchy_map; + std::set label_set = std::set(labels, labels + M); + for (const auto& label : label_set) { + auto search = hierarchy_all_map.find(label); + CAFFE_ENFORCE(search != hierarchy_all_map.end(), "incorrect label."); + hierarchy_map.emplace(search->first, search->second); + } + return hierarchy_map; + } + int getIntermediateOutputSize( + const int* labels, + int M, + std::unordered_map& hierarchy) const { + int size = 0; + for (int label = 0; label < M; ++label) { + int word_id = labels[label]; + const auto& path = hierarchy[word_id]; + size += std::accumulate( + path.path_nodes().begin(), + path.path_nodes().end(), + 0, + // Output of FC + Output of Softmax + [](int sz, PathNodeProto node) { + return sz + 2 * node.length(); + }); + } + return size; + } +}; + +template +class HSoftmaxOp : public HSoftmaxOpBase { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + using HSoftmaxOpBase::HSoftmaxOpBase; + + bool RunOnDevice() override; + + protected: float RunForwardSingle( const float* X, const float* W, @@ -36,61 +84,16 @@ class HSoftmaxOp final : public Operator { int w_length, int K, int& output_offset); - static constexpr T kLOG_THRESHOLD() { - return 1e-20; - } - // TODO(Deepak): Make search more efficient, maybe? - static std::unordered_map getHierarchyForLabels( - int M, - const int* labels, - const HierarchyProto& hierarchy) { - std::unordered_map hierarchy_map; - std::set label_set = std::set(labels, labels + M); - for (const PathProto& path : hierarchy.paths()) { - if (label_set.count(path.word_id()) > 0) { - hierarchy_map.emplace(path.word_id(), path); - } - } - return hierarchy_map; - } - int getIntermediateOutputSize( - const int* labels, - int M, - std::unordered_map& hierarchy) { - int size = 0; - for (int label = 0; label < M; ++label) { - int word_id = labels[label]; - const auto& path = hierarchy[word_id]; - size += std::accumulate( - path.path_nodes().begin(), - path.path_nodes().end(), - 0, - // Output of FC + Output of Softmax - [](int size, PathNodeProto node) { - return size + 2 * node.length(); - }); - } - return size; - } }; template -class HSoftmaxGradientOp final : public Operator { +class HSoftmaxGradientOp final : public HSoftmaxOpBase { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - HSoftmaxGradientOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) { - hierarchy_.ParseFromString( - OperatorBase::GetSingleArgument("hierarchy", "")); - } + using HSoftmaxOpBase::HSoftmaxOpBase; bool RunOnDevice() override; private: - HierarchyProto hierarchy_; - Tensor scale_; - Tensor sum_multiplier_; - Tensor bias_multiplier_; - DISABLE_COPY_AND_ASSIGN(HSoftmaxGradientOp); void RunBackwardSingle( const float* X, const float* dY, @@ -104,42 +107,37 @@ class HSoftmaxGradientOp final : public Operator { int dim_in, int w_length, int& output_offset); - static constexpr T kLOG_THRESHOLD() { - return 1e-20; - } - // TODO(Deepak): Make search more efficient, maybe? - static std::unordered_map getHierarchyForLabels( - int M, - const int* labels, - const HierarchyProto& hierarchy) { - std::unordered_map hierarchy_map; - std::set label_set = std::set(labels, labels + M); - for (const PathProto& path : hierarchy.paths()) { - if (label_set.count(path.word_id()) > 0) { - hierarchy_map.emplace(path.word_id(), path); - } - } - return hierarchy_map; - } - int getIntermediateOutputSize( - const int* labels, - int M, - std::unordered_map& hierarchy) { - int size = 0; - for (int label = 0; label < M; ++label) { - int word_id = labels[label]; - const auto& path = hierarchy[word_id]; - size += std::accumulate( - path.path_nodes().begin(), - path.path_nodes().end(), - 0, - // Output of FC + Output of Softmax - [](int size, PathNodeProto node) { - return size + 2 * node.length(); - }); - } - return size; +}; + +template +class HSoftmaxSearchOp final : public HSoftmaxOp { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + HSoftmaxSearchOp(const OperatorDef& operator_def, Workspace* ws) + : HSoftmaxOp(operator_def, ws), + top_n_(OperatorBase::GetSingleArgument("topN", 5)), + beam_(OperatorBase::GetSingleArgument("beam", 0.01)) { + tree_.ParseFromString(OperatorBase::GetSingleArgument("tree", "")); } + bool RunOnDevice() override; + + private: + int top_n_; + float beam_; + TreeProto tree_; + bool pruning( + const float* X, + int sample, + int K, + const float* W, + const float* b, + const NodeProto& src_node, + NodeProto& dst_node, + float parent_score, + float beam); + bool extractNodes( + const NodeProto& node, + std::vector>& info); }; } // namespace caffe2 diff --git a/caffe2/operators/load_save_op.cc b/caffe2/operators/load_save_op.cc index 201f7deaf8a0..b8224b7cebe3 100644 --- a/caffe2/operators/load_save_op.cc +++ b/caffe2/operators/load_save_op.cc @@ -36,7 +36,11 @@ DBReader to load from, and we ignore the db and db_type arguments. "keep_device", "(int, default 0) if nonzero, the blobs are loaded into the device that " "is specified in the serialized BlobProto. Otherwise, the device will be " - "set as the one that the Load operator is being run under."); + "set as the one that the Load operator is being run under.") + .Arg( + "load_all", + "(int, default 0) if nonzero, will load all blobs pointed to by the db " + "to the workspace overwriting/creating blobs as needed."); OPERATOR_SCHEMA(Save).NumInputs(1, INT_MAX).NumOutputs(0) .SetDoc(R"DOC( diff --git a/caffe2/operators/load_save_op.h b/caffe2/operators/load_save_op.h index c7138ee542ea..38e4dd6bd108 100644 --- a/caffe2/operators/load_save_op.h +++ b/caffe2/operators/load_save_op.h @@ -29,24 +29,26 @@ class LoadOp final : public Operator { OperatorBase::GetSingleArgument("absolute_path", false)), db_name_(OperatorBase::GetSingleArgument("db", "")), db_type_(OperatorBase::GetSingleArgument("db_type", "")), - keep_device_(OperatorBase::GetSingleArgument("keep_device", 0)) { + keep_device_(OperatorBase::GetSingleArgument("keep_device", 0)), + load_all_(OperatorBase::GetSingleArgument("load_all", 0)) { if (InputSize() == 0) { CHECK_GT(db_name_.size(), 0) << "Must specify a db name."; CHECK_GT(db_type_.size(), 0) << "Must specify a db type."; } - int idx = 0; - for (const string& output_name : this->def().output()) { - output_indices_[output_name] = idx++; + if (!load_all_) { + int idx = 0; + for (const string& output_name : this->def().output()) { + output_indices_[output_name] = idx++; + } } } void SetCurrentDevice(BlobProto* proto); bool RunOnDevice() override { - const vector& outputs = OperatorBase::Outputs(); if (InputSize() == 1) { const db::DBReader& reader = OperatorBase::Input(0); - extractFrom(reader.cursor(), outputs); + extract(reader.cursor()); } else { string full_db_name = absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_); @@ -54,12 +56,50 @@ class LoadOp final : public Operator { caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ)); CAFFE_ENFORCE(in_db.get(), "Cannot open db: ", db_name_); std::unique_ptr cursor(in_db->NewCursor()); - extractFrom(cursor.get(), outputs); + extract(cursor.get()); } + return true; } private: + void extract(Cursor* cursor) { + if (load_all_) { + extractAll(cursor); + } else { + extractFrom(cursor, OperatorBase::Outputs()); + } + } + + void extractAll(Cursor* cursor) { + CAFFE_ENFORCE(cursor, "cursor is not valid"); + std::unordered_set seen_blobs; + for (; cursor->Valid(); cursor->Next()) { + const string& key = cursor->key(); + BlobProto proto; + CAFFE_ENFORCE( + proto.ParseFromString(cursor->value()), "Couldn't parse Proto"); + if (!keep_device_) { + // If we are not keeping the device as the one specified in the + // proto, we will set the current device. + SetCurrentDevice(&proto); + } + + if (seen_blobs.count(key) == 0 && ws_->GetBlob(key)) { + // This blob already exists, reset it, read below about why! + ws_->GetBlob(key)->Reset(); + } + + Blob* blob = ws_->CreateBlob(key); + CAFFE_ENFORCE(blob->Deserialize(proto), "Couldn't deserialize blob"); + if (!blob->IsType>()) { + // Only tensors can be seen multiple times as chunks. + CAFFE_ENFORCE(seen_blobs.count(key) == 0, "Blob duplicated"); + } + seen_blobs.insert(key); + } + } + void extractFrom(Cursor* cursor, const vector& outputs) { CHECK(cursor); @@ -155,6 +195,7 @@ class LoadOp final : public Operator { string db_name_; string db_type_; bool keep_device_; + bool load_all_; std::map output_indices_; }; @@ -188,6 +229,13 @@ class SaveOp final : public Operator { transaction->Put(blobName, data); transaction->Commit(); }; + std::set input_names; + for (int i = 0; i < inputs.size(); ++i) { + CAFFE_ENFORCE( + input_names.insert(def().input(i)).second, + "Duplicated feature: ", + def().input(i)); + } for (int i = 0; i < inputs.size(); ++i) { inputs[i]->Serialize(def().input(i), acceptor); } diff --git a/caffe2/operators/lp_pool_op.cc b/caffe2/operators/lp_pool_op.cc new file mode 100644 index 000000000000..9861f23a038c --- /dev/null +++ b/caffe2/operators/lp_pool_op.cc @@ -0,0 +1,273 @@ +// TODO: reduce the apparent redundancy of all the code below. +#include "caffe2/operators/pool_op.h" + +namespace caffe2 { + +using std::min; +using std::max; + +class LpPool {}; + +template <> +bool PoolOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto* Y = Output(0); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1)); + const auto p = OperatorBase::GetSingleArgument("p", 2.0); + const auto inv_p = 1.0 / p; + + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + math::Set(Y->size(), 0, Ydata, &context_); + // The main loop + int channels = X.dim32(1); + int height = X.dim32(2); + int width = X.dim32(3); + int pooled_height = Y->dim32(2); + int pooled_width = Y->dim32(3); + + for (int n = 0; n < X.dim32(0); ++n) { + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + const int pool_index = ph * pooled_width + pw; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int input_index = h * width + w; + Ydata[pool_index] += std::pow(std::abs(Xdata[input_index]), p); + } + } + Ydata[pool_index] = std::pow(Ydata[pool_index], inv_p); + } + } + // Do offset. + Xdata += height * width; + Ydata += pooled_height * pooled_width; + } + } + return true; +} + +template <> +bool PoolOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto* Y = Output(0); + int height = X.dim32(1); + int width = X.dim32(2); + int channels = X.dim32(3); + ConvPoolOpBase::SetOutputSize(X, Y, channels); + + const auto p = OperatorBase::GetSingleArgument("p", 2.0); + const auto inv_p = 1.0 / p; + + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + math::Set(Y->size(), 0, Ydata, &context_); + // The main loop + int pooled_height = Y->dim32(1); + int pooled_width = Y->dim32(2); + for (int n = 0; n < X.dim32(0); ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + const int pool_index = (ph * pooled_width + pw) * channels; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int input_index = (h * width + w) * channels; + for (int c = 0; c < channels; ++c) { + Ydata[pool_index + c] += + std::pow(std::abs(Xdata[input_index + c]), p); + } + } + } + for (int c = 0; c < channels; ++c) { + Ydata[pool_index + c] = std::pow(Ydata[pool_index + c], inv_p); + } + } + } + // Do offset. + Xdata += X.size() / X.dim32(0); + Ydata += Y->size() / Y->dim32(0); + } + return true; +} + +template <> +bool PoolGradientOp::RunOnDeviceWithOrderNCHW() { + const auto& X = Input(0); + const auto& Y = Input(1); + auto& dY = Input(2); + auto* dX = Output(0); + const auto p = OperatorBase::GetSingleArgument("p", 2.0); + const auto inv_p = 1.0 / p; + + // TODO(Yangqing): Add shape checks. + dX->ResizeLike(X); + math::Set( + X.size(), 0, dX->mutable_data(), &context_); + const float* dYdata = dY.data(); + const float* Xdata = X.data(); + const float* Ydata = Y.data(); + float* dXdata = dX->mutable_data(); + + int channels = X.dim32(1); + CHECK_EQ(channels, dY.dim32(1)); + int height = X.dim32(2); + int width = X.dim32(3); + ConvPoolOpBase::ComputePads(height, width); + int pooled_height = dY.dim32(2); + int pooled_width = dY.dim32(3); + // The main loop + for (int n = 0; n < X.dim32(0); ++n) { + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + float scale = 1. / (hend - hstart) / (wend - wstart); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + // gradient of p-norm is x_j * |x_j|^{p-2} / |x|_p^{p-1} + dXdata[h * width + w] += dYdata[ph * pooled_width + pw] * + Xdata[h * width + w] * + std::pow(std::abs(Xdata[h * width + w]), p - 2) / + std::pow(Ydata[ph * pooled_width + pw], p - 1); + } + } + } + } + // offset + dXdata += height * width; + dYdata += pooled_height * pooled_width; + Ydata += pooled_height * pooled_width; + Xdata += height * width; + } + } + return true; +} + +template <> +bool PoolGradientOp::RunOnDeviceWithOrderNHWC() { + const auto& X = Input(0); + const auto& Y = Input(1); + auto& dY = Input(2); + CHECK_EQ(dY.ndim(), 4); + auto* dX = Output(0); + // TODO(Yangqing): Add shape checks. + dX->ResizeLike(X); + math::Set( + X.size(), 0, dX->mutable_data(), &context_); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + const float* Xdata = X.data(); + const float* Ydata = Y.data(); + // The main loop + int height = X.dim32(1); + int width = X.dim32(2); + ConvPoolOpBase::ComputePads(height, width); + const auto p = OperatorBase::GetSingleArgument("p", 2.0); + const auto inv_p = 1.0 / p; + + int pooled_height = dY.dim32(1); + int pooled_width = dY.dim32(2); + int channels = X.dim32(3); + CHECK_EQ(channels, dY.dim32(3)); + for (int n = 0; n < X.dim32(0); ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + float scale = 1. / (hend - hstart) / (wend - wstart); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + for (int c = 0; c < channels; ++c) { + dXdata[(h * width + w) * channels + c] += + dYdata[(ph * pooled_width + pw) * channels + c] * + Xdata[(h * width + w) * channels + c] * + std::pow( + std::abs(Xdata[(h * width + w) * channels + c]), p - 2) / + std::pow( + Ydata[(ph * pooled_width + pw) * channels + c], p - 1); + } + } + } + } + } + // offset + dXdata += X.size() / X.dim32(0); + dYdata += dY.size() / dY.dim32(0); + Xdata += X.size() / X.dim32(0); + Ydata += Y.size() / Y.dim32(0); + } + return true; +} + +namespace { + +REGISTER_CPU_OPERATOR(LpPool, PoolOp); +REGISTER_CPU_OPERATOR( + LpPoolGradient, + PoolGradientOp); + +OPERATOR_SCHEMA(LpPool) + .NumInputs(1) + .NumOutputs(1) + .SetDoc(R"DOC( + +LpPool consumes an input blob X and applies L-p pooling across the +the blob according to kernel sizes, stride sizes, and pad lengths defined by the +ConvPoolOpBase operator. L-p pooling consisting of taking the L-p norm of a +subset of the input tensor according to the kernel size and downsampling the +data into the output blob Y for further processing. + + )DOC") + .Input( + 0, + "X", + "Input data tensor from the previous operator; dimensions " + "depend on whether the NCHW or NHWC operators are being used. For example, " + "in the former, the input has size (N x C x H x W), where N is the batch " + "size, C is the number of channels, and H and W are the height and the width " + "of the data. The corresponding permutation of dimensions is used in the " + "latter case. ") + .Output( + 0, + "Y", + "Output data tensor from L-p pooling across the input " + "tensor. Dimensions will vary based on various kernel, stride, and pad " + "sizes."); + +OPERATOR_SCHEMA(LpPoolGradient).NumInputs(3).NumOutputs(1); + +class GetPoolGradient : public GradientMakerBase { + using GradientMakerBase::GradientMakerBase; + vector GetGradientDefs() override { + return SingleGradientDef( + def_.type() + "Gradient", + "", + vector{I(0), O(0), GO(0)}, + vector{GI(0)}); + } +}; +REGISTER_GRADIENT(LpPool, GetPoolGradient); +} +} diff --git a/caffe2/operators/lp_pool_op.cu b/caffe2/operators/lp_pool_op.cu new file mode 100644 index 000000000000..81b1d07d6252 --- /dev/null +++ b/caffe2/operators/lp_pool_op.cu @@ -0,0 +1,349 @@ +// TODO: reduce the apparent redundancy of all the code below. +#include + +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/pool_op.h" + +namespace caffe2 { +namespace { +class LpPool {}; +} // namespace + +namespace { +template +__global__ void LpPoolForwardNCHW( + const int nthreads, + const T* bottom_data, + const int num, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_t, + const int pad_l, + T* top_data, + const T p) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int n = index; + int pw = n % pooled_width; + n /= pooled_width; + int ph = n % pooled_height; + n /= pooled_height; + int c = n % channels; + n /= channels; + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + top_data[index] = 0; + int bottom_offset = (n * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + top_data[index] += + std::pow(std::abs(bottom_data[bottom_offset + h * width + w]), p); + } + } + top_data[index] = std::pow(top_data[index], 1.0 / p); + } +} + +template +__global__ void LpPoolForwardNHWC( + const int nthreads, + const T* bottom_data, + const int num, + const int height, + const int width, + const int channels, + const int pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_t, + const int pad_l, + T* top_data, + const T p) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int c = index % channels; + int pw = (index / channels) % pooled_width; + int ph = (index / channels / pooled_width) % pooled_height; + int n = index / channels / pooled_width / pooled_height; + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + T output = 0; + int bottom_offset = n * height * width * channels + c; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + output += std::pow( + std::abs(bottom_data[bottom_offset + (h * width + w) * channels]), + p); + } + } + top_data[index] = std::pow(output, 1.0 / p); + } +} + +template +__global__ void LpPoolBackwardNCHW( + const int nthreads, + const T* const top_diff, + const T* const top_data, + const T* const bottom_data, + const int num, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_t, + const int pad_l, + T* const bottom_diff, + const int p) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + const int w = index % width + pad_l; + const int h = (index / width) % height + pad_t; + const int c = (index / width / height) % channels; + const int n = index / width / height / channels; + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + T gradient = 0; + const T* const top_diff_slice = + top_diff + (n * channels + c) * pooled_height * pooled_width; + const T* const top_data_slice = + top_data + (n * channels + c) * pooled_height * pooled_width; + + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + gradient += top_diff_slice[ph * pooled_width + pw] * + bottom_data[index] * std::pow(std::abs(bottom_data[index]), p - 2) / + std::pow(top_data_slice[ph * pooled_width + pw], p - 1); + } + } + bottom_diff[index] = gradient; + } +} + +template +__global__ void LpPoolBackwardNHWC( + const int nthreads, + const T* const top_diff, + const T* const top_data, + const T* const bottom_data, + const int num, + const int height, + const int width, + const int channels, + const int pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_t, + const int pad_l, + T* const bottom_diff, + const T p) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + const int c = index % channels; + const int w = index / channels % width + pad_l; + const int h = (index / channels / width) % height + pad_t; + const int n = index / channels / width / height; + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + T gradient = 0; + const T* const top_diff_slice = + top_diff + n * pooled_height * pooled_width * channels + c; + const T* const top_data_slice = + top_data + n * pooled_height * pooled_width * channels + c; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + gradient += top_diff_slice[(ph * pooled_width + pw) * channels] * + bottom_data[index] * std::pow(std::abs(bottom_data[index]), p - 2) / + std::pow(top_data_slice[(ph * pooled_width + pw) * channels], + p - 1); + } + } + bottom_diff[index] = gradient; + } +} + +} // namespace + +template <> +bool PoolOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto* Y = Output(0); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1)); + int output_size = Y->size(); + LpPoolForwardNCHW<<< + CAFFE_GET_BLOCKS(output_size), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + output_size, + X.data(), + X.dim32(0), + X.dim32(1), + X.dim32(2), + X.dim32(3), + Y->dim32(2), + Y->dim32(3), + kernel_h_, + kernel_w_, + stride_h_, + stride_w_, + pad_t_, + pad_l_, + Y->mutable_data(), + OperatorBase::GetSingleArgument("p", 2.0)); + return true; +} + +template <> +bool PoolOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto* Y = Output(0); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(3)); + int output_size = Y->size(); + LpPoolForwardNHWC<<< + CAFFE_GET_BLOCKS(output_size), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + output_size, + X.data(), + X.dim32(0), + X.dim32(1), + X.dim32(2), + X.dim32(3), + Y->dim32(1), + Y->dim32(2), + kernel_h_, + kernel_w_, + stride_h_, + stride_w_, + pad_t_, + pad_l_, + Y->mutable_data(), + OperatorBase::GetSingleArgument("p", 2.0)); + return true; +} + +template <> +bool PoolGradientOp:: + RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto& Y = Input(1); + auto& dY = Input(2); + CHECK_EQ(dY.ndim(), 4); + auto* dX = Output(0); + dX->ResizeLike(X); + ConvPoolOpBase::ComputePads(X.dim32(2), X.dim32(3)); + LpPoolBackwardNCHW<<< + CAFFE_GET_BLOCKS(X.size()), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + X.size(), + dY.data(), + Y.data(), + X.data(), + X.dim32(0), + X.dim32(1), + X.dim32(2), + X.dim32(3), + dY.dim32(2), + dY.dim32(3), + kernel_h_, + kernel_w_, + stride_h_, + stride_w_, + pad_t_, + pad_l_, + dX->mutable_data(), + OperatorBase::GetSingleArgument("p", 2.0)); + return true; +} + +template <> +bool PoolGradientOp:: + RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto& Y = Input(1); + auto& dY = Input(2); + CHECK_EQ(dY.ndim(), 4); + auto* dX = Output(0); + dX->ResizeLike(X); + ConvPoolOpBase::ComputePads(X.dim32(1), X.dim32(2)); + LpPoolBackwardNHWC<<< + CAFFE_GET_BLOCKS(X.size()), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + X.size(), + dY.data(), + Y.data(), + X.data(), + X.dim32(0), + X.dim32(1), + X.dim32(2), + X.dim32(3), + dY.dim32(1), + dY.dim32(2), + kernel_h_, + kernel_w_, + stride_h_, + stride_w_, + pad_t_, + pad_l_, + dX->mutable_data(), + OperatorBase::GetSingleArgument("p", 2.0)); + return true; +} + +namespace { +REGISTER_CUDA_OPERATOR(LpPool, PoolOp); +REGISTER_CUDA_OPERATOR( + LpPoolGradient, + PoolGradientOp); +} +} diff --git a/caffe2/operators/metrics_ops.cc b/caffe2/operators/metrics_ops.cc new file mode 100644 index 000000000000..923236ec0857 --- /dev/null +++ b/caffe2/operators/metrics_ops.cc @@ -0,0 +1,53 @@ +#include "caffe2/operators/metrics_ops.h" + +namespace caffe2 { +namespace { +REGISTER_CPU_OPERATOR(CreateQPSMetric, CreateQPSMetricOp); +REGISTER_CPU_OPERATOR(QPSMetric, QPSMetricOp); +REGISTER_CPU_OPERATOR(QPSMetricReport, QPSMetricReportOp); + +OPERATOR_SCHEMA(CreateQPSMetric) + .NumInputs(0) + .NumOutputs(1) + .SetDoc(R"DOC( +CreateQPSMetric operator create a blob that will store state that is required +for computing QPSMetric. The only output of the operator will have blob with +QPSMetricState as an output. +)DOC") + .Output(0, "output", "Blob with QPSMetricState"); + +OPERATOR_SCHEMA(QPSMetric) + .NumInputs(2) + .NumOutputs(1) + .SetDoc(R"DOC( +QPSMetric operator syncronously updates metric storedcreate a blob that will +store state that is required for computing QPSMetric. The only output of the +operator will have blob with QPSMetricState as an output. +)DOC") + .Input( + 0, + "QPS_METRIC_STATE", + "Input Blob QPSMetricState, that needs to be updated") + .Input( + 1, + "INPUT_BATCH", + "Input Blob containing a tensor with batch of the examples." + " First dimension of the batch will be used to get the number of" + " examples in the batch.") + .Output(0, "output", "Blob with QPSMetricState") + .EnforceInplace({{0, 0}}); + +OPERATOR_SCHEMA(QPSMetricReport) + .NumInputs(1) + .NumOutputs(0) + .SetDoc(R"DOC( +QPSMetricReport operator that syncronously consumes the QPSMetricState blob and +reports the information about QPS. +)DOC") + .Output(0, "output", "Blob with QPSMetricState"); + +SHOULD_NOT_DO_GRADIENT(CreateQPSMetric); +SHOULD_NOT_DO_GRADIENT(QPSMetric); +SHOULD_NOT_DO_GRADIENT(QPSMetricReport); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/metrics_ops.h b/caffe2/operators/metrics_ops.h new file mode 100644 index 000000000000..ee2daaa129e4 --- /dev/null +++ b/caffe2/operators/metrics_ops.h @@ -0,0 +1,85 @@ +#pragma once + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/core/timer.h" + +#include + +namespace caffe2 { +namespace { +struct QPSMetricState { + Timer lifetimeTimer; + Timer windowTimer; + int64_t windowExamples{0}; + int64_t lifetimeExamples{0}; + + std::mutex mutex; +}; +} + +CAFFE_KNOWN_TYPE(std::unique_ptr); + +// TODO(amalevich): Consider making all the code below templated, so it'll be +// easier to share it across different metrics. +class CreateQPSMetricOp final : public Operator { + public: + using Operator::Operator; + + bool RunOnDevice() override { + *OperatorBase::Output>(0) = + caffe2::make_unique(); + return true; + } +}; + +class QPSMetricOp final : public Operator { + public: + using Operator::Operator; + + bool RunOnDevice() override { + auto& metricsBlob = + *OperatorBase::Input>(0); + auto examples = Input(1).dim(0); + // All changes to metrics should happen under critical section. + { + std::lock_guard guard(metricsBlob.mutex); + metricsBlob.windowExamples += examples; + metricsBlob.lifetimeExamples += examples; + } + return true; + } +}; + +class QPSMetricReportOp final : public Operator { + public: + using Operator::Operator; + + bool RunOnDevice() override { + auto& metricsBlob = + *OperatorBase::Input>(0); + // All changes to metrics should happen under critical section. + float windowSeconds = -1; + int64_t windowExamples = 0; + float lifetimeSeconds = -1; + int64_t lifetimeExamples = 0; + { + std::lock_guard guard(metricsBlob.mutex); + windowSeconds = metricsBlob.windowTimer.Seconds(); + lifetimeSeconds = metricsBlob.lifetimeTimer.Seconds(); + windowExamples = metricsBlob.windowExamples; + lifetimeExamples = metricsBlob.lifetimeExamples; + + metricsBlob.windowTimer.Start(); + metricsBlob.windowExamples = 0; + } + // TODO(amalevich): Add output blobs, so it would be relatively easy to + // access this metrics from the outside + LOG(INFO) << "Overal QPS = " + << (static_cast(lifetimeExamples) / lifetimeSeconds) + << ", Window QPS = " + << (static_cast(windowExamples) / windowSeconds); + return true; + } +}; +} diff --git a/caffe2/operators/pack_segments.cc b/caffe2/operators/pack_segments.cc index 55333d34e6ce..b1b2c6d77458 100644 --- a/caffe2/operators/pack_segments.cc +++ b/caffe2/operators/pack_segments.cc @@ -5,6 +5,7 @@ #include #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" +#include "caffe2/utils/math.h" namespace caffe2 { @@ -54,9 +55,12 @@ class PackSegmentsOp final : public Operator { shape.insert(shape.begin(), lengths.size()); output->Resize(shape); - // Do zero padding - float* data_ptr = output->template mutable_data(); - memset(data_ptr, padding_, sizeof(float) * output->size()); + // Do padding + math::Set( + output->size(), + padding_, + output->template mutable_data(), + &context_); int block_size = data.size() / data.dim(0); int block_bytesize = data.nbytes() / data.dim(0); diff --git a/caffe2/operators/packed_fc_op.cc b/caffe2/operators/packed_fc_op.cc index a4d392a12ddf..77ffaa853f61 100644 --- a/caffe2/operators/packed_fc_op.cc +++ b/caffe2/operators/packed_fc_op.cc @@ -17,7 +17,21 @@ class PackedFCOp final : public Operator { USE_OPERATOR_FUNCTIONS(CPUContext); PackedFCOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), - axis_(OperatorBase::GetSingleArgument("axis", 1)) {} + axis_(OperatorBase::GetSingleArgument("axis", 1)) { + OPERATOR_NEEDS_FEATURE( + __builtin_cpu_supports("avx2") || operator_def.type() == "PackedFC", + "If you are trying to use PackedFCOp as a FC with PACKED engine on " + "a machine that does not have avx2, be noted that the functionality " + "is not tuned and you are better off directly using FC."); + // TODO(jiayq): after MKL update, remove this constraint. This is different + // from the check above, as the above is a performance hint and the below + // is about correctness. + CAFFE_ENFORCE( + __builtin_cpu_supports("avx2"), + "Do not run PackedFC on a machine that does not have avx2 " + "right now, as there is an known issue with MKL 2017.0.098 " + "that produces wrong results on non-avx2 machines."); + } ~PackedFCOp() {} bool RunOnDevice() override { @@ -50,35 +64,47 @@ class PackedFCOp final : public Operator { if (!local_packed_matrix_.get() || local_packed_matrix_->n_ != M) { // If there is no pre packed matrix, or the batch size changed, we // do a re-pack. - // Note that the packed sgemm follows the blas interfaces, not cblas local_packed_matrix_.reset(new MKLPackedMatrix( - 'A', 'T', N, M, K, 1.f, W.template data(), K)); + CblasBMatrix, + CblasTrans, + M, + N, + K, + 1.f, + W.template data(), + K)); } packed_matrix = local_packed_matrix_.get(); } else if (OperatorBase::InputIsType(1)) { packed_matrix = &OperatorBase::Input(1); } - CAFFE_ENFORCE_EQ(packed_matrix->m_, N); + CAFFE_ENFORCE_EQ(packed_matrix->m_, M); CAFFE_ENFORCE_EQ(packed_matrix->k_, K); - CAFFE_ENFORCE_EQ(packed_matrix->n_, M); + CAFFE_ENFORCE_EQ(packed_matrix->n_, N); // Do we want to check the other flags as well? - Y->Resize(M, N); + Y_shape_cache_ = X.dims(); + // This is an invariant of canonical_axis, so we can DCHECK. + DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); + Y_shape_cache_.resize(canonical_axis + 1); + Y_shape_cache_[canonical_axis] = N; + Y->Resize(Y_shape_cache_); + CAFFE_ENFORCE(M * N == Y->size()); - const float kZero = 0; - sgemm_compute( - "P", - "N", - &N, - &M, - &K, - packed_matrix->data_, - &K, + cblas_sgemm_compute( + CblasRowMajor, + CblasNoTrans, + CblasPacked, + M, + N, + K, X.template data(), - &K, - &kZero, + K, + packed_matrix->data_, + K, + 0, Y->template mutable_data(), - &N); + N); // Add bias term if (bias_multiplier_.size() != M) { @@ -113,6 +139,7 @@ class PackedFCOp final : public Operator { } size_t axis_{1}; uint32_t hash_{0}; + vector Y_shape_cache_; Tensor bias_multiplier_; std::unique_ptr local_packed_matrix_; }; @@ -120,6 +147,7 @@ class PackedFCOp final : public Operator { } // namespace mkl REGISTER_CPU_OPERATOR(PackedFC, mkl::PackedFCOp); +REGISTER_CPU_OPERATOR_WITH_ENGINE(FC, PACKED, mkl::PackedFCOp); OPERATOR_SCHEMA(PackedFC).NumInputs(3).NumOutputs(1).SetDoc(R"DOC( Computes the result of passing an input vector X into a fully connected diff --git a/caffe2/operators/partition_ops.cc b/caffe2/operators/partition_ops.cc index 84a0b0496a90..ecbdf65c3482 100644 --- a/caffe2/operators/partition_ops.cc +++ b/caffe2/operators/partition_ops.cc @@ -6,13 +6,12 @@ namespace { REGISTER_CPU_OPERATOR(Partition, PartitionOp); REGISTER_CPU_OPERATOR(LengthsPartition, LengthsPartitionOp); -OPERATOR_SCHEMA(Shard) +OPERATOR_SCHEMA(Partition) .NumInputsOutputs([](int in, int out) { return in > 0 && out > 0 && out % in == 0; }) .SetDoc(R"DOC( -Sharding splits the input int tensor into multiple ones according to the first -tensor. +Splits the input int tensor into multiple ones according to the first tensor. Takes the first input and partitions it to shards according to the remainder of values modulo the number of partitions. It requires that the first tensor is of @@ -35,21 +34,21 @@ X_0_part_0, X_1_part_0, ..., X_N-1_part_0, X_0_part_1, ..., X_N-1_part_K-1 .Input( 0, "input", - "Input tensor containing data to be sharded. The " + "Input tensor containing data to be partitioned. The " "number of input tensors might be greater than 1 but must have the " "same shape as the previous tensors.") .Output( 0, - "shards", - "Output Shards. The number of output shards has to be a " - "multiple of the number of input shards."); + "partitions", + "Output Partitions. The number of output tensors has to be a " + "multiple of the number of input tensors."); -OPERATOR_SCHEMA(LengthsSharding) +OPERATOR_SCHEMA(LengthsPartition) .NumInputsOutputs([](int in, int out) { return in >= 2 && out > 0 && out % in == 0; }) .SetDoc(R"DOC( -LengthsSharding splits the input int tensor into multiple ones according to the +LengthsPartition splits the input int tensor into multiple ones according to the second tensor. The first dimension is expected to be the tensor that describes lengths of the elements. @@ -76,19 +75,19 @@ X_0_part_0, X_1_part_0, ..., X_N-1_part_0, X_0_part_1, ..., X_N-1_part_K-1 .Input( 0, "input", - "Input tensor containing data to be sharded. The " + "Input tensor containing data to be partitioned. The " "number of input tensors might be greater than 1 but must have the " "same shape as the previous tensors.") .Output( 0, - "shards", - "Output Shards. The number of output shards has to be a " - "multiple of the number of input shards."); + "partitions", + "Output Partitions. The number of output tensors has to be a " + "multiple of the number of input tensors."); // This should actually have gradient, but for now nothing uses it. // Because gradient computation right now is not input/output aware it can't be // GRADIENT_NOT_IMPLEMENTEDYET -NO_GRADIENT(Sharding); -NO_GRADIENT(ShardingLengths); +NO_GRADIENT(Partition); +NO_GRADIENT(LengthsPartition); } // namespace } // namespace caffe2 diff --git a/caffe2/operators/pool_op.cc b/caffe2/operators/pool_op.cc index d75c0ef025a2..365537388eaa 100644 --- a/caffe2/operators/pool_op.cc +++ b/caffe2/operators/pool_op.cc @@ -1,5 +1,6 @@ // TODO: reduce the apparent redundancy of all the code below. #include "caffe2/operators/pool_op.h" +#include "caffe2/utils/cpu_neon.h" namespace caffe2 { @@ -11,6 +12,154 @@ namespace { // template to instantiate the different algorithms. class AveragePool {}; class MaxPool {}; + +#ifdef __ARM_NEON__ + +bool isNeonEligible(int inputH, int inputW, + int outputH, int outputW, + int kH, int kW, + int strideH, int strideW, + int padT, int padL, int padB, int padR, + int dilationH, int dilationW, + const float* input, + float* output) { + // Use this kernel only if: + // Kernel width is 4x4 + // Kernel stride is 4x4 + // Padding is 0 + // Dilation is 1 + // Output width and height are even divisors of input width + // Input width and height are divisible by 4 (should be implied by + // all of the above, but just check again) + // Input and output pointers are aligned by float32x4_t + + bool kernelOk = (kH == 4) && (kW == 4); + bool strideOk = (strideH == 4) && (strideW == 4); + bool padOk = (padT == 0) && (padL == 0) && (padB == 0) && (padR == 0); + bool dilationOk = (dilationH == 1) && (dilationW == 1); + + bool outputOk = ((inputH % outputH) == 0) && ((inputW % outputW) == 0); + bool inputOk = (inputW % 4 == 0) && (inputH % 4 == 0); + bool alignOk = isPointerAligned(input, sizeof(float32x4_t)) && + isPointerAligned(output, sizeof(float32x4_t)); + + return kernelOk && strideOk && padOk && dilationOk && + outputOk && inputOk && alignOk; +} + +// Vectorizes 4x4p0s0 averge pooling for ARM NEON +void avgPoolNeon4x4p0s0Plane(int inputH, int inputW, + const float* input, + float* output) { + constexpr int kKernelHeight = 4; + constexpr int kKernelWidth = 4; + constexpr float kDiv = + (1.0f / ((float) kKernelHeight * (float) kKernelWidth)); + + // Handle portion that can be unrolled by 4 + constexpr int kUnroll = 4; + constexpr int kLoadSizeFloat = (sizeof(float32x4_t) / sizeof(float)); + constexpr int kLoadCols = kUnroll * kLoadSizeFloat; + + if (inputW % kLoadCols == 0) { + // + // Manually unroll by 4 (kUnroll) + // + + for (int h = 0; h < inputH; h += kKernelHeight) { + float* outputRow = output + (h / kKernelHeight) * (inputW / kKernelWidth); + const float* curInput = input + h * inputW; + + for (int w = 0; w < inputW; w += kLoadCols) { + float32x4_t out = {}; + + { + float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW); + float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW); + float32x4_t v0_2 = vld1q_f32_aligned(curInput + 2 * inputW); + float32x4_t v0_3 = vld1q_f32_aligned(curInput + 3 * inputW); + float v0 = horizontal_sum_f32(v0_0, v0_1, v0_2, v0_3); + out = vsetq_lane_f32(v0, out, 0); + } + curInput += kLoadSizeFloat; + + { + float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW); + float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW); + float32x4_t v0_2 = vld1q_f32_aligned(curInput + 2 * inputW); + float32x4_t v0_3 = vld1q_f32_aligned(curInput + 3 * inputW); + float v0 = horizontal_sum_f32(v0_0, v0_1, v0_2, v0_3); + out = vsetq_lane_f32(v0, out, 1); + } + curInput += kLoadSizeFloat; + + { + float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW); + float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW); + float32x4_t v0_2 = vld1q_f32_aligned(curInput + 2 * inputW); + float32x4_t v0_3 = vld1q_f32_aligned(curInput + 3 * inputW); + float v0 = horizontal_sum_f32(v0_0, v0_1, v0_2, v0_3); + out = vsetq_lane_f32(v0, out, 2); + } + curInput += kLoadSizeFloat; + + { + float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW); + float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW); + float32x4_t v0_2 = vld1q_f32_aligned(curInput + 2 * inputW); + float32x4_t v0_3 = vld1q_f32_aligned(curInput + 3 * inputW); + float v0 = horizontal_sum_f32(v0_0, v0_1, v0_2, v0_3); + out = vsetq_lane_f32(v0, out, 3); + } + curInput += kLoadSizeFloat; + + out = vmulq_f32(out, vdupq_n_f32(kDiv)); + vst1q_f32_aligned(&outputRow[w / kKernelWidth], out); + } + } + } else { + // + // Not unrolled + // + + for (int h = 0; h < inputH; h += kKernelHeight) { + const float* inputRow = input + h * inputW; + float* outputRow = output + (h / kKernelHeight) * (inputW / kKernelWidth); + + for (int w = 0; w < inputW; w += kKernelWidth) { + const float* curInput = inputRow + w; + + float32x4_t v0_0 = vld1q_f32_aligned(curInput + 0 * inputW); + float32x4_t v0_1 = vld1q_f32_aligned(curInput + 1 * inputW); + float32x4_t v0_2 = vld1q_f32_aligned(curInput + 2 * inputW); + float32x4_t v0_3 = vld1q_f32_aligned(curInput + 3 * inputW); + float v0 = horizontal_sum_f32(v0_0, v0_1, v0_2, v0_3) * kDiv; + outputRow[w / kKernelWidth] = v0; + } + } + } +} + +void +runNeonAveragePool4x4p0s0NCHW(int N, int C, int inputH, int inputW, + const float* input, + float* output) { + // We only have the 4x4p0s0 implementation at present, which is + // checked at a higher level + int outputH = inputH / 4; + int outputW = inputW / 4; + + for (int n = 0; n < N; ++n) { + for (int c = 0; c < C; ++c) { + const float* curInput = input + (n * C + c) * inputH * inputW; + float* curOutput = output + (n * C + c) * outputH * outputW; + + avgPoolNeon4x4p0s0Plane(inputH, inputW, curInput, curOutput); + } + } +} +#endif // __ARM_NEON__ + } // namespace template <> @@ -29,6 +178,23 @@ bool PoolOp::RunOnDeviceWithOrderNCHW() { int width = X.dim32(3); int pooled_height = Y->dim32(2); int pooled_width = Y->dim32(3); + +#ifdef __ARM_NEON__ + // We specialize certain variants on ARM for vectorization + if (isNeonEligible(X.dim32(2), X.dim32(3), + Y->dim32(2), Y->dim32(3), + kernel_h_, kernel_w_, + stride_h_, stride_w_, + pad_t_, pad_l_, pad_b_, pad_r_, + dilation_h_, dilation_w_, + Xdata, Ydata)) { + runNeonAveragePool4x4p0s0NCHW(X.dim32(0), X.dim32(1), + X.dim32(2), X.dim32(3), + Xdata, Ydata); + return true; + } +#endif // __ARM_NEON__ + for (int n = 0; n < X.dim32(0); ++n) { for (int c = 0; c < channels; ++c) { for (int ph = 0; ph < pooled_height; ++ph) { diff --git a/caffe2/operators/prelu_op.cc b/caffe2/operators/prelu_op.cc new file mode 100644 index 000000000000..3ac9d2c36448 --- /dev/null +++ b/caffe2/operators/prelu_op.cc @@ -0,0 +1,300 @@ +#include "caffe2/operators/prelu_op.h" + +#include "caffe2/utils/cpu_neon.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +#ifdef __ARM_NEON__ +namespace { + +void runNeonPrelu(float* out, const float* in, int size, float w) { + float32x4_t vZero = vdupq_n_f32(0.0f); + float32x4_t vW = vdupq_n_f32(w); + + constexpr int kVecSizeInFloat = sizeof(float32x4_t) / sizeof(float); + + if (size < kVecSizeInFloat) { + for (int i = 0; i < size; ++i) { + float v = in[i]; + out[i] = v > 0 ? v : v * w; + } + + return; + } + + // We want to load aligned from the input, but assume the output is unaligned + int prologue = + kVecSizeInFloat - + // remainder in floats + (((uintptr_t) in) % (sizeof(float32x4_t))) / sizeof(float); + + int i = 0; + + // Prologue loop + for (; i < prologue; ++i) { + float v = in[i]; + out[i] = v > 0 ? v : v * w; + } + + // The loop is manually unrolled by 6; seems to be the limit for + // armv7 to avoid register spills + constexpr int kUnroll = 6; + constexpr int kFloatsPerLoop = kUnroll * kVecSizeInFloat; + + int remainder = size - prologue; + int vectorizable = prologue + (remainder / kFloatsPerLoop) * kFloatsPerLoop; + + for (; i < vectorizable; i += kFloatsPerLoop) { + float32x4_t v0 = vld1q_f32_aligned(in + i + 0); + float32x4_t v1 = vld1q_f32_aligned(in + i + 4); + float32x4_t v2 = vld1q_f32_aligned(in + i + 8); + float32x4_t v3 = vld1q_f32_aligned(in + i + 12); + float32x4_t v4 = vld1q_f32_aligned(in + i + 16); + float32x4_t v5 = vld1q_f32_aligned(in + i + 20); + + uint32x4_t gz0 = vcgtq_f32(v0, vZero); + uint32x4_t gz1 = vcgtq_f32(v1, vZero); + uint32x4_t gz2 = vcgtq_f32(v2, vZero); + uint32x4_t gz3 = vcgtq_f32(v3, vZero); + uint32x4_t gz4 = vcgtq_f32(v4, vZero); + uint32x4_t gz5 = vcgtq_f32(v5, vZero); + + float32x4_t v0neg = vmulq_f32(v0, vW); + float32x4_t v1neg = vmulq_f32(v1, vW); + float32x4_t v2neg = vmulq_f32(v2, vW); + float32x4_t v3neg = vmulq_f32(v3, vW); + float32x4_t v4neg = vmulq_f32(v4, vW); + float32x4_t v5neg = vmulq_f32(v5, vW); + + // v0 > 0 ? v0 : v0 * w + v0 = vbslq_f32(gz0, v0, v0neg); + v1 = vbslq_f32(gz1, v1, v1neg); + v2 = vbslq_f32(gz2, v2, v2neg); + v3 = vbslq_f32(gz3, v3, v3neg); + v4 = vbslq_f32(gz4, v4, v4neg); + v5 = vbslq_f32(gz5, v5, v5neg); + + vst1q_f32(out + i + 0, v0); + vst1q_f32(out + i + 4, v1); + vst1q_f32(out + i + 8, v2); + vst1q_f32(out + i + 12, v3); + vst1q_f32(out + i + 16, v4); + vst1q_f32(out + i + 20, v5); + } + + for (; i < size; ++i) { + float v = in[i]; + out[i] = v > 0 ? v : v * w; + } +} + +} +#endif // __ARM_NEON__ + +template <> +bool PReluOp::RunOnDevice() { + const auto& X = Input(0); + const auto& W = Input(1); + auto* Y = Output(0); + Y->ResizeLike(X); + const auto* Xdata = X.template data(); + const auto* Wdata = W.template data(); + auto* Ydata = Y->template mutable_data(); + + const auto C = order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(X.ndim() - 1); + const auto C_shared = (W.size() == 1); + + if (!C_shared) { + CAFFE_ENFORCE_EQ(C, W.size()); + } + + if (C_shared) { +#ifdef __ARM_NEON__ + // The function is completely pointwise + runNeonPrelu(Ydata, Xdata, X.size(), Wdata[0]); +#else + ConstEigenVectorMap Xvec(Xdata, X.size()); + EigenVectorMap Yvec(Ydata, Y->size()); + Yvec = Xvec.cwiseMax(0.f) + Xvec.cwiseMin(0.f) * Wdata[0]; + return true; +#endif // __ARM_NEON__ + } + + // non-shared case. + switch (order_) { + case StorageOrder::NCHW: { + const auto N = X.dim(0); + const auto dim = X.size_from_dim(2); + +#ifdef __ARM_NEON__ + // Pointwise for each channel + for (int n = 0; n < N; ++n) { + for (int c = 0; c < C; ++c) { + runNeonPrelu(Ydata + (n * C + c) * dim, + Xdata + (n * C + c) * dim, + dim, Wdata[c]); + } + } +#else + int nc = 0; + for (int n = 0; n < N; ++n) { + for (int c = 0; c < C; ++c) { + ConstEigenVectorMap Xvec(Xdata + nc * dim, dim); + EigenVectorMap(Ydata + nc * dim, dim) = + Xvec.cwiseMax(0.f) + Xvec.cwiseMin(0.f) * Wdata[c]; + nc++; + } + } +#endif + break; + } + case StorageOrder::NHWC: { + // Lay out matrix as (NHW, C) and multiply by C + const auto NHW = X.size() / C; + ConstEigenArrayMap Xmat(Xdata, C, NHW); + ConstEigenVectorArrayMap Wvec(Wdata, C); + EigenArrayMap Ymat(Ydata, C, NHW); + Ymat = (Xmat > 0).select(Xmat, Xmat.colwise() * Wvec); + break; + } + default: + CAFFE_THROW("Unknown storage order: ", order_); + } + return true; +} + +template <> +bool PReluGradientOp::RunOnDevice() { + auto& Y = Input(0); + auto& dY = Input(1); + auto& X = Input(2); + auto& W = Input(3); + + CAFFE_ENFORCE(&Y != &X, "Cannot backpropagate through an in-place PReLU"); + auto* dX = Output(0); + auto* dW = Output(1); + + DCHECK_GT(Y.size(), 0); + DCHECK_EQ(dY.size(), Y.size()); + dX->ResizeLike(Y); + dW->ResizeLike(W); + + const auto C = order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(X.ndim() - 1); + const auto C_shared = (W.size() == 1); + + const float* Ydata = Y.data(); + const float* dYdata = dY.data(); + const float* Xdata = X.data(); + const float* Wdata = W.data(); + float* dXdata = dX->mutable_data(); + float* dWdata = dW->mutable_data(); + + // non-shared case. + switch (order_) { + case StorageOrder::NCHW: { + const auto dim = X.size_from_dim(2); + const auto div_factor = C_shared ? C : 1; + for (auto c = 0; c < W.size(); ++c) { + dWdata[c] = 0; + } + + for (int i = 0; i < Y.size(); ++i) { + if (Xdata[i] <= 0) { + int c = (i / dim) % C / div_factor; + dWdata[c] += Ydata[i] * Xdata[i]; + } + } + + for (int i = 0; i < Y.size(); ++i) { + if (Xdata[i] > 0) { + dXdata[i] = dYdata[i]; + } else { + int c = (i / dim) % C / div_factor; + dXdata[i] = Wdata[c] * dYdata[i]; + } + } + break; + } + case StorageOrder::NHWC: { + const auto NHW = X.size() / C; + ConstEigenVectorArrayMap Wvec(Wdata, W.size()); + EigenVectorArrayMap dWvec(dWdata, dW->size()); + + ConstEigenArrayMap Ymat(Ydata, C, NHW); + ConstEigenArrayMap dYmat(dYdata, C, NHW); + ConstEigenArrayMap Xmat(Xdata, C, NHW); + EigenArrayMap dXmat(dXdata, C, NHW); + + if (C_shared) { + dXmat = (Xmat > 0).select(dYmat, dYmat * Wdata[0]); + dWdata[0] = + (Xmat > 0) + .select( + Xmat.cwiseMin(0.0f), // zero gradients on the 'if' path. + Ymat * Xmat) + .sum(); + } else { + dXmat = (Xmat > 0).select(dYmat, dYmat.colwise() * Wvec); + dWvec = (Xmat > 0) + .select( + Xmat.cwiseMin(0.0f), // zero gradients on the 'if' path. + Ymat * Xmat) + .rowwise() + .sum(); + } + break; + } + default: + CAFFE_THROW("Unknown storage order: ", order_); + } + + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(PRelu, PReluOp); +REGISTER_CPU_OPERATOR(PReluGradient, PReluGradientOp); + +// Input: X, Slope, output: Y +OPERATOR_SCHEMA(PRelu) + .NumInputs(2) + .NumOutputs(1) + .AllowInplace({{0, 0}}) + .SetDoc(R"DOC( + +PRelu takes input data (Tensor) and slope tensor as input, and produces one +output data (Tensor) where the function `f(x) = slope * x for x < 0`, +`f(x) = x for x >= 0`., is applied to the data tensor elementwise. + +)DOC") + .Input(0, "X", "1D input tensor") + .Input( + 1, + "Slope", + "1D slope tensor. If `Slope` is of size 1, the value is shared" + "across different channels") + .Output(0, "Y", "1D input tensor"); + +// Input: Y, dY, output: dX +OPERATOR_SCHEMA(PReluGradient).NumInputs(4).NumOutputs(2).SetDoc(R"DOC( + +PReluGradient takes both Y and dY and uses this to update dX and dW according +to the chain rule and derivatives of the rectified linear function. + +)DOC"); + +class GetPReluGradient : public GradientMakerBase { + using GradientMakerBase::GradientMakerBase; + vector GetGradientDefs() override { + return SingleGradientDef( + def_.type() + "Gradient", + "", + vector{O(0), GO(0), I(0), I(1)}, + vector{GI(0), GI(1)}); + } +}; +REGISTER_GRADIENT(PRelu, GetPReluGradient); + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/prelu_op.h b/caffe2/operators/prelu_op.h new file mode 100644 index 000000000000..54e1afa54ccf --- /dev/null +++ b/caffe2/operators/prelu_op.h @@ -0,0 +1,40 @@ +#pragma once + +#include "caffe2/core/context.h" +#include "caffe2/core/logging.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +template +class PReluOp final : public Operator { + public: + PReluOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument("order", "NCHW"))) {} + + USE_OPERATOR_CONTEXT_FUNCTIONS; + + bool RunOnDevice() override; + + protected: + StorageOrder order_; +}; + +template +class PReluGradientOp final : public Operator { + public: + PReluGradientOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument("order", "NCHW"))) {} + USE_OPERATOR_CONTEXT_FUNCTIONS; + + bool RunOnDevice() override; + + protected: + StorageOrder order_; +}; + +} // namespace caffe2 diff --git a/caffe2/operators/softmax_op.cc b/caffe2/operators/softmax_op.cc index 7bf40dc6e39e..0259508670c2 100644 --- a/caffe2/operators/softmax_op.cc +++ b/caffe2/operators/softmax_op.cc @@ -1,4 +1,5 @@ #include "caffe2/operators/softmax_op.h" +#include "caffe2/operators/softmax_shared.h" namespace caffe2 { @@ -7,9 +8,9 @@ template <> bool SoftmaxOp::RunOnDevice() { auto& X = Input(0); auto* Y = Output(0); - DCHECK_EQ(X.ndim(), 2); - int N = X.dim32(0); - int D = X.dim32(1); + const auto canonical_axis = X.canonical_axis_index(axis_); + const int N = X.size_to_dim(canonical_axis); + const int D = X.size_from_dim(canonical_axis); Y->ResizeLike(X); float* Ydata = Y->mutable_data(); // First, get scales @@ -21,29 +22,8 @@ bool SoftmaxOp::RunOnDevice() { math::Set(D, 1.f, sum_multiplier_.mutable_data(), &context_); } - math::RowwiseMax(N, D, X.data(), scale_.mutable_data(), - &context_); - // Put the intermediate result X - max(X) into Y - context_.template Copy( - X.size(), X.data(), Ydata); - // Subtract the scale - math::Gemm(CblasNoTrans, CblasNoTrans, N, D, 1, - -1, scale_.data(), sum_multiplier_.data(), 1, - Ydata, &context_); - // Exponentiation - math::Exp(Y->size(), Ydata, Ydata, - &context_); - math::Gemv(CblasNoTrans, N, D, 1, Ydata, - sum_multiplier_.data(), 0, - scale_.mutable_data(), &context_); - // Do division - // TODO(Yangqing): maybe implement it more beautifully? - const float* scale = scale_.data(); - for (int i = 0; i < N; ++i) { - for (int j = 0; j < D; ++j) { - Ydata[i * D + j] /= scale[i]; - } - } + + SoftmaxCPU(context_, N, D, X, Ydata, scale_, sum_multiplier_); return true; } @@ -53,11 +33,9 @@ bool SoftmaxGradientOp::RunOnDevice() { auto& Y = Input(0); auto& dY = Input(1); auto* dX = Output(0); - DCHECK_EQ(Y.ndim(), 2); - int N = Y.dim32(0); - int D = Y.dim32(1); - DCHECK_EQ(dY.dim32(0), N); - DCHECK_EQ(dY.dim32(1), D); + const auto canonical_axis = Y.canonical_axis_index(axis_); + const int N = Y.size_to_dim(canonical_axis); + const int D = Y.size_from_dim(canonical_axis); // First, get scales if (scale_.size() != N) { scale_.Resize(N); @@ -67,7 +45,7 @@ bool SoftmaxGradientOp::RunOnDevice() { math::Set(D, 1.f, sum_multiplier_.mutable_data(), &context_); } - dX->Resize(N, D); + dX->ResizeLike(Y); const float* Ydata = Y.data(); const float* dYdata = dY.data(); float* dXdata = dX->mutable_data(); diff --git a/caffe2/operators/softmax_op.cu b/caffe2/operators/softmax_op.cu index ff2cd7d42bc9..482d92cc0e18 100644 --- a/caffe2/operators/softmax_op.cu +++ b/caffe2/operators/softmax_op.cu @@ -91,31 +91,29 @@ __global__ void softmax_gradient_kernel( } } // namespace -// Implementation for the CPU context. +// Implementation for the CUDA context. template <> bool SoftmaxOp::RunOnDevice() { auto& X = Input(0); auto* Y = Output(0); - DCHECK_EQ(X.ndim(), 2); - int N = X.dim32(0); - int D = X.dim32(1); + const auto canonical_axis = X.canonical_axis_index(axis_); + const int N = X.size_to_dim(canonical_axis); + const int D = X.size_from_dim(canonical_axis); Y->ResizeLike(X); softmax_kernel<<>>( D, X.data(), Y->mutable_data()); return true; } -// Implementation for the CPU context. +// Implementation for the CUDA context. template <> bool SoftmaxGradientOp::RunOnDevice() { auto& Y = Input(0); auto& dY = Input(1); auto* dX = Output(0); - DCHECK_EQ(Y.ndim(), 2); - int N = Y.dim32(0); - int D = Y.dim32(1); - DCHECK_EQ(dY.dim32(0), N); - DCHECK_EQ(dY.dim32(1), D); + const auto canonical_axis = Y.canonical_axis_index(axis_); + const int N = Y.size_to_dim(canonical_axis); + const int D = Y.size_from_dim(canonical_axis); dX->ResizeLike(Y); softmax_gradient_kernel<<>>( diff --git a/caffe2/operators/softmax_op.h b/caffe2/operators/softmax_op.h index 81a7b8ce696d..ad958a8bffd3 100644 --- a/caffe2/operators/softmax_op.h +++ b/caffe2/operators/softmax_op.h @@ -11,11 +11,14 @@ namespace caffe2 { template class SoftmaxOp final : public Operator { public: - USE_SIMPLE_CTOR_DTOR(SoftmaxOp); + SoftmaxOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + axis_(OperatorBase::GetSingleArgument("axis", 1)) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override; protected: + int axis_; Tensor scale_; Tensor sum_multiplier_; }; @@ -23,11 +26,14 @@ class SoftmaxOp final : public Operator { template class SoftmaxGradientOp final : public Operator { public: - USE_SIMPLE_CTOR_DTOR(SoftmaxGradientOp); + SoftmaxGradientOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + axis_(OperatorBase::GetSingleArgument("axis", 1)) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override; protected: + int axis_; Tensor scale_; Tensor sum_multiplier_; }; diff --git a/caffe2/operators/softmax_shared.cc b/caffe2/operators/softmax_shared.cc new file mode 100644 index 000000000000..680849821114 --- /dev/null +++ b/caffe2/operators/softmax_shared.cc @@ -0,0 +1,55 @@ +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +void SoftmaxCPU( + CPUContext& context, + const int N, + const int D, + const Tensor& X, + float* Ydata, + Tensor& scale, + Tensor& sum_multiplier) { + math::RowwiseMax( + N, D, X.data(), scale.mutable_data(), &context); + // Put the intermediate result X - max(X) into Y + context.template Copy( + X.size(), X.data(), Ydata); + // Subtract the max (for nomuerical reasons) + math::Gemm( + CblasNoTrans, + CblasNoTrans, + N, + D, + 1, + -1, + scale.data(), + sum_multiplier.data(), + 1, + Ydata, + &context); + // Exponentiation + math::Exp(N * D, Ydata, Ydata, &context); + math::Gemv( + CblasNoTrans, + N, + D, + 1, + Ydata, + sum_multiplier.data(), + 0, + scale.mutable_data(), + &context); + // Do division + // TODO(Yangqing): maybe implement it more beautifully? + const float* s = scale.data(); + for (int i = 0; i < N; ++i) { + for (int j = 0; j < D; ++j) { + Ydata[i * D + j] /= s[i]; + } + } +} + +} // namespace caffe2 diff --git a/caffe2/operators/softmax_shared.h b/caffe2/operators/softmax_shared.h new file mode 100644 index 000000000000..0910e3a0675a --- /dev/null +++ b/caffe2/operators/softmax_shared.h @@ -0,0 +1,19 @@ +#ifndef CAFFE2_OPERATORS_SOFTMAX_SHARED_H_ +#define CAFFE2_OPERATORS_SOFTMAX_SHARED_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +void SoftmaxCPU( + CPUContext& context, + const int N, + const int D, + const Tensor& X, + float* Ydata, + Tensor& scale, + Tensor& sum_multiplier); +} // namespace caffe2 + +#endif // #define CAFFE2_OPERATORS_SOFTMAX_SHARED_H_ diff --git a/caffe2/operators/softmax_with_loss_op.cc b/caffe2/operators/softmax_with_loss_op.cc new file mode 100644 index 000000000000..602de4c98b1e --- /dev/null +++ b/caffe2/operators/softmax_with_loss_op.cc @@ -0,0 +1,278 @@ +#include "softmax_with_loss_op.h" +#include "softmax_shared.h" + +namespace caffe2 { + +REGISTER_CPU_OPERATOR(SoftmaxWithLoss, SoftmaxWithLossOp); +REGISTER_CPU_OPERATOR( + SoftmaxWithLossGradient, + SoftmaxWithLossGradientOp); + +// Input: X (logits), T (labels); Output: P (probs), Y +OPERATOR_SCHEMA(SoftmaxWithLoss).NumOutputs(2).SetDoc(R"DOC( +Combined Softmax and Cross-Entropy loss operator. +The operator computes the softmax normalized values for each layer in the batch +of the given input, after which cross-entropy loss is computed. This operator is +numerically more stable than separate Softmax and CrossEntropy ops. +The inputs are a 2-D tensor (Tensor) of size +(batch_size x input_feature_dimensions) and tensor of labels (ground truth). +Output is tensor with the probability for each label for each example (N x D) +and averaged loss (scalar). Use parameter spatial=1 to enable spatial softmax. +Spatial softmax also supports special \"don't care\" label (-1) that is ignored +when computing the loss. + +For spatial version additional weight blob can be added as the third input. +)DOC"); +// Input: X, T, P, dY; Output: dX +OPERATOR_SCHEMA(SoftmaxWithLossGradient).NumOutputs(1); + +#define DONT_CARE (-1) + +template <> +bool SoftmaxWithLossOp::RunOnDevice() { + auto& X = Input(0); // Logits + auto& T = Input(1); // Labels / targets + auto* P = Output(0); // Probabilities from softmax + auto* avg_loss = Output(1); // Average loss + int N = X.dim32(0); + int D = X.dim32(1); + + P->ResizeLike(X); + + if (sum_multiplier_.size() != D) { + sum_multiplier_.Resize(D); + math::Set( + D, 1.f, sum_multiplier_.mutable_data(), &context_); + } + + float* Pdata = P->mutable_data(); + + if (!spatial_mode_) { + DCHECK_EQ(X.ndim(), 2); + DCHECK((T.ndim() == 1) || (T.ndim() == 2 && T.dim32(1) == 1)); + DCHECK_EQ(T.dim32(0), N); + + if (sum_multiplier_.size() != D) { + sum_multiplier_.Resize(D); + math::Set( + D, 1.f, sum_multiplier_.mutable_data(), &context_); + } + + Tensor scalef; + scalef.Resize(N); // TOOD: what's the role of scale? + + SoftmaxCPU(context_, N, D, X, Pdata, scalef, sum_multiplier_); + + // Then compute cross entropy + const int* label_data = T.data(); + float loss_sum = 0.0; + for (int i = 0; i < N; ++i) { + CAFFE_ENFORCE( + label_data[i] < D, + "Label seems incorrect: label value larger than number of classes: ", + label_data[i], + " vs ", + D); + float l = -log(std::max(Pdata[i * D + label_data[i]], 1e-20f)); + loss_sum += l; + } + + avg_loss->Resize(vector()); + float* avg_loss_data = avg_loss->mutable_data(); + avg_loss_data[0] = loss_sum * scale_ / N; + } else { + // Spatial mode, compute softmax for each x, y location + DCHECK_EQ(X.ndim(), 4); + DCHECK_EQ(T.ndim(), 3); + + int H = X.dim32(2); + int W = X.dim32(3); + + const float* weights = (InputSize() > 2 ? Input(2).data() : nullptr); + const float* Xdata = X.data(); + + for (int i = 0; i < N; ++i) { + for (int y = 0; y < H; ++y) { + for (int x = 0; x < W; ++x) { + // Subtract max on each cell for numerical reasons + float max_val = (-1e20f); + for (int c = 0; c < D; ++c) { + // TODO optimize + int idx = i * (H * W * D) + c * (H * W) + y * W + x; + max_val = std::max(max_val, Xdata[idx]); + } + + // Exponentiate + float expsum = 0.0f; + for (int c = 0; c < D; ++c) { + int idx = i * (H * W * D) + c * (H * W) + y * W + x; + float expx = exp(Xdata[idx] - max_val); + Pdata[idx] = expx; + expsum += expx; + } + + // Normalize + for (int c = 0; c < D; ++c) { + int idx = i * (H * W * D) + c * (H * W) + y * W + x; + Pdata[idx] /= expsum; + } + } + } + } + + // Compute the avg cross-entropy loss + avg_loss->Resize(vector()); + float* avg_loss_data = avg_loss->mutable_data(); + const int* label_data = T.data(); + + float sum_label_xent = 0.0f; + float total_weight = 0.0; + + for (int y = 0; y < H; y++) { + for (int x = 0; x < W; x++) { + for (int i = 0; i < N; i++) { + int label_idx = i * H * W + y * W + x; + int label = label_data[label_idx]; + if (label != DONT_CARE) { + int idx = i * (H * W * D) + label * (H * W) + y * W + x; + float w = weights ? weights[label_idx] : 1.0; + total_weight += w; + sum_label_xent += -log(std::max(Pdata[idx], 1e-20f)) * w; + } + } + } + } + *avg_loss_data = sum_label_xent / total_weight; + } // if spatial + return true; +} + +template <> +bool SoftmaxWithLossGradientOp::RunOnDevice() { + auto& X = Input(0); // Logits + auto& T = Input(1); // Labels / targets + // Input(2) is weights if given + auto& P = Input(InputSize() - 2); // Probabilities from softmax + auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss + auto* dX = Output(0); + + int N = X.dim32(0); + int D = X.dim32(1); + dX->ResizeLike(X); + DCHECK_EQ(T.dim32(0), N); + + if (!spatial_mode_) { + DCHECK_EQ(X.ndim(), 2); + DCHECK((T.ndim() == 1) || (T.ndim() == 2 && T.dim32(1) == 1)); + + const float* Pdata = P.data(); + float* dX_data = dX->mutable_data(); + const int* label_data = T.data(); + + // Copy softmax probabilities into dX. All but the neuron + // corresponding to the correct label has gradient equaling e(x_j) + // which is the probability under softmax. + context_.Copy(P.size(), Pdata, dX_data); + + // Compute gradient for the matching labels. + for (int i = 0; i < N; ++i) { + int idx = i * D + label_data[i]; + dX_data[idx] = Pdata[idx] - 1.0f; + } + + // Scale by d_avg_loss / N + math::Scale( + dX->size(), + scale_ / N * d_avg_loss.data()[0], + dX->data(), + dX_data, + &context_); + } else { + // Spatial mode, compute softmax for each x, y location + DCHECK_EQ(X.ndim(), 4); + DCHECK_EQ(T.ndim(), 3); + + int H = X.dim32(2); + int W = X.dim32(3); + + const float* weights = (InputSize() > 4 ? Input(2).data() : nullptr); + + const float* Pdata = P.data(); + float* dX_data = dX->mutable_data(); + const int* label_data = T.data(); + + // Copy softmax probabilities into dX. All but the neuron + // corresponding to the correct label has gradient equaling e(x_j) + // which is the probability under softmax. + context_.Copy(P.size(), Pdata, dX_data); + + float total_weight = 0.0f; + for (int y = 0; y < H; ++y) { + for (int x = 0; x < W; ++x) { + for (int i = 0; i < N; ++i) { + int label_idx = i * H * W + y * W + x; + int label = label_data[label_idx]; + + if (label != DONT_CARE) { + int idx = i * (H * W * D) + label * (H * W) + y * W + x; + + dX_data[idx] = (dX_data[idx] - 1.0); + + if (weights != nullptr) { + float weight = weights[label_idx]; + for (int c = 0; c < D; ++c) { + int k = i * (H * W * D) + c * (H * W) + y * W + x; + dX_data[k] *= weight; + } + total_weight += weight; + } else { + total_weight += 1.0; + } + } else { + + // Set gradient to zero for coordinates where we have dont care + for (int c = 0; c < D; ++c) { + int idx = i * (H * W * D) + c * (H * W) + y * W + x; + dX_data[idx] = 0; + } + } + } + } + } + + math::Scale( + dX->size(), + scale_ / total_weight, + dX->data(), + dX_data, + &context_); + math::Scale( + dX->size(), + d_avg_loss.data(), + dX->data(), + dX->mutable_data(), + &context_); + } + return true; +} + +namespace { +class GetSoftmaxWithLossGradient : public GradientMakerBase { + using GradientMakerBase::GradientMakerBase; + vector GetGradientDefs() override { + vector blob_names{ + {I(0), I(1), O(0), GO(1)}, + }; + + // Add weight blob, if given + if (def_.input_size() == 3) { + blob_names.emplace(blob_names.begin() + 2, I(2)); + } + return SingleGradientDef( + "SoftmaxWithLossGradient", "", blob_names, vector{GI(0)}); + } +}; + +REGISTER_GRADIENT(SoftmaxWithLoss, GetSoftmaxWithLossGradient); +} +} // namespace caffe2 diff --git a/caffe2/operators/softmax_with_loss_op.cu b/caffe2/operators/softmax_with_loss_op.cu new file mode 100644 index 000000000000..886d5c3a26b5 --- /dev/null +++ b/caffe2/operators/softmax_with_loss_op.cu @@ -0,0 +1,396 @@ +#include + +#include "caffe2/core/context_gpu.h" +#include "softmax_with_loss_op.h" + +namespace caffe2 { + +namespace { + +__global__ void LabelCrossEntropyKernel( + const int N, const int D, const float* Pdata, const int* labeldata, + float* Ydata) { + CUDA_1D_KERNEL_LOOP(i, N) { + CUDA_KERNEL_ASSERT(labeldata[i] < D); + Ydata[i] = -logf(max(Pdata[i * D + labeldata[i]], FLT_MIN)); + } +} + +__global__ void LabelCrossEntropyGradientKernel( + const int N, const int D, const float* Pdata, const int* labeldata, + float* dXdata) { + CUDA_1D_KERNEL_LOOP(i, N) { + int idx = i * D + labeldata[i]; + dXdata[idx] = Pdata[idx] - 1.; + } +} + +__global__ void RowMaxKernel(const int num, const int D, const float* data, + float* out) { + CUDA_1D_KERNEL_LOOP(index, num) { + float maxval = -FLT_MAX; + for (int d = 0; d < D; ++d) { + maxval = max(data[index * D + d], maxval); + } + out[index] = maxval; + } +} + + +__global__ void SpatialSoftmaxKernel(const int num, const int D, const int W, const int H, + const float* Xdata, float* Pdata) { + CUDA_1D_KERNEL_LOOP(i, num) { + for(int y = 0; y < H; ++y) { + for(int x = 0; x < W; ++x) { + // Subtract max on each cell for numerical reasons + float max_val = -FLT_MAX; + for(int c = 0; c < D; ++c) { + // TODO optimize + int idx = i * (H * W * D) + c * (H * W) + y * W + x; + max_val = max(max_val, Xdata[idx]); + } + + // Exponentiate + float expsum = 0.0f; + for(int c = 0; c < D; ++c) { + int idx = i * (H * W * D) + c * (H * W) + y * W + x; + float expx = exp(Xdata[idx] - max_val); + Pdata[idx] = expx; + expsum += expx; + } + + // Normalize + for(int c=0; ccuda_stream()>>>(N, D, logits, scales); + // Put the intermediate result X - max(X) into Y + context->Copy(size, logits, probs); + // Subtract the scale + math::Gemm(CblasNoTrans, CblasNoTrans, N, D, 1, + -1, scales, sum_multiplier, 1, probs, context); + // Exponentiation + math::Exp(size, probs, probs, context); + // Sum exponentiated values + math::Gemv(CblasNoTrans, N, D, 1, probs, sum_multiplier, + 0, scales, context); + // Normalize + SoftmaxNormalizeKernel<<cuda_stream()>>>( + size, D, probs, scales, probs); +} + +} // namespace + +template<> +bool SoftmaxWithLossOp::RunOnDevice() { + auto& X = Input(0); // Logits + auto& T = Input(1); // Labels / targets + auto* P = Output(0); // Probabilities from softmax + auto* avg_loss = Output(1); // Average loss + int N = X.dim32(0); + int D = X.dim32(1); + P->ResizeLike(X); + + if (!spatial_mode_) { + DCHECK_EQ(X.ndim(), 2); + DCHECK((T.ndim() == 1) || (T.ndim() == 2 && T.dim32(1) == 1)); + DCHECK_EQ(T.dim32(0), N); + + avg_loss->Resize(vector()); + if (losses_.size() != N) { + losses_.Resize(N); + } + if (sum_multiplier_.size() != D) { + sum_multiplier_.Resize(D); + math::Set( + D, 1.f, sum_multiplier_.mutable_data(), &context_); + } + Softmax(N, D, X.data(), T.data(), sum_multiplier_.data(), + losses_.mutable_data(), P->mutable_data(), &context_); + // Compute label xent loss per example + LabelCrossEntropyKernel<<>>( + N, D, P->data(), T.data(), losses_.mutable_data()); + // Sum of all losses + float* avg_loss_data = avg_loss->mutable_data(); + math::Sum( + losses_.size(), losses_.data(), avg_loss_data, &context_); + // Average of input batch size + math::Scale( + 1, scale_ / N, avg_loss_data, avg_loss_data, &context_); + } else { + DCHECK_EQ(X.ndim(), 4); + DCHECK_EQ(T.ndim(), 3); + + int H = X.dim32(2); + int W = X.dim32(3); + + const float* weights = (InputSize() > 2 ? Input(2).data() : NULL); + const float* Xdata = X.data(); + float* Pdata = P->mutable_data(); + + // Softmax for each x,y location + SpatialSoftmaxKernel<<>>( + N, D, W, H, Xdata, Pdata); + + // Cross entropy + avg_loss->Resize(vector()); + float* avg_loss_data = avg_loss->mutable_data(); + math::Set(1, 0.0f, avg_loss_data, &context_); + + const int* label_data = T.data(); + float* total_weight_ptr; + cudaMalloc(&total_weight_ptr, sizeof(float)); + math::Set(1, 0.0f, total_weight_ptr, &context_); + + // TODO: how to set best? + dim3 threadsPerBlock(REDUCTION_KERNEL_THREADS_X, REDUCTION_KERNEL_THREADS_Y); + dim3 numBlocks(1, 1); + SpatialCrossEntropyLossKernel<<>>( + N, D, W, H, P->data(), label_data, weights, + avg_loss_data, total_weight_ptr); + + + // Somewhat awkward scalar passing from device to host + float h_total_weight; + cudaMemcpyAsync(&h_total_weight, total_weight_ptr, sizeof(float), + cudaMemcpyDeviceToHost, context_.cuda_stream()); + cudaFree(total_weight_ptr); + + // Final scaling + math::Scale( + 1, scale_ / h_total_weight, + avg_loss_data, avg_loss_data, &context_); + + } + return true; +} + + +template<> +bool SoftmaxWithLossGradientOp::RunOnDevice() { + auto& X = Input(0); // Logits + auto& T = Input(1); // Labels / targets + // Input(2) is weights, if given + auto& P = Input(InputSize() - 2); // Probabilities from softmax + auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss + auto* dX = Output(0); + int N = X.dim32(0); + int D = X.dim32(1); + dX->ResizeLike(X); + + if (!spatial_mode_) { + DCHECK_EQ(X.ndim(), 2); + DCHECK((T.ndim() == 1) || (T.ndim() == 2 && T.dim32(1) == 1)); + DCHECK_EQ(T.dim32(0), N); + // Copy softmax probabilities into dX + context_.Copy( + P.size(), P.data(), dX->mutable_data()); + // Subtract 1 from labeled positions + LabelCrossEntropyGradientKernel<<>>( + N, D, P.data(), T.data(), dX->mutable_data()); + // Scale by d_avg_loss / N + math::Scale( + dX->size(), scale_ / N, dX->data(), + dX->mutable_data(), &context_); + math::Scale( + dX->size(), d_avg_loss.data(), dX->data(), + dX->mutable_data(), &context_); + } else { + // Spatial mode, compute softmax for each x, y location + DCHECK_EQ(X.ndim(), 4); + DCHECK_EQ(T.ndim(), 3); + + int H = X.dim32(2); + int W = X.dim32(3); + dX->ResizeLike(X); + + const float* weights = (InputSize() > 4 ? Input(2).data() : NULL); + const float* Pdata = P.data(); + float* dX_data = dX->mutable_data(); + const int* label_data = T.data(); + const float* d_avg_loss_data = d_avg_loss.data(); + + // Copy softmax probabilities into dX. All but the neuron + // corresponding to the correct label has gradient equaling e(x_j) + // which is the probability under softmax. + context_.Copy(P.size(), Pdata, dX_data); + + // TODO: how to set best? + dim3 threadsPerBlock(REDUCTION_KERNEL_THREADS_X, REDUCTION_KERNEL_THREADS_Y); + dim3 numBlocks(1, 1); + + float* total_weight_ptr; + cudaMalloc(&total_weight_ptr, sizeof(float)); + math::Set(1, 0.0f, total_weight_ptr, &context_); + + SpatialSoftmaxLossGradientKernel<<>>( + N, D, W, H, label_data, weights, dX_data, + total_weight_ptr); + + // Somewhat awkward scalar passing from device to host + float h_total_weight; + cudaMemcpyAsync(&h_total_weight, total_weight_ptr, sizeof(float), + cudaMemcpyDeviceToHost, context_.cuda_stream()); + cudaFree(total_weight_ptr); + + // Final scaling + math::Scale( + dX->size(), + scale_ / h_total_weight, + dX->data(), + dX->mutable_data(), &context_); + math::Scale( + dX->size(), + d_avg_loss.data(), + dX->data(), + dX->mutable_data(), &context_); + } + return true; +} + + +namespace { +REGISTER_CUDA_OPERATOR(SoftmaxWithLoss, + SoftmaxWithLossOp); +REGISTER_CUDA_OPERATOR(SoftmaxWithLossGradient, + SoftmaxWithLossGradientOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/softmax_with_loss_op.h b/caffe2/operators/softmax_with_loss_op.h new file mode 100644 index 000000000000..7d6270278015 --- /dev/null +++ b/caffe2/operators/softmax_with_loss_op.h @@ -0,0 +1,63 @@ +#ifndef SOFTMAX_WITH_LOSS_OP_H_ +#define SOFTMAX_WITH_LOSS_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/logging.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +template +class SoftmaxWithLossOp final : public Operator { + public: + SoftmaxWithLossOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + scale_(OperatorBase::GetSingleArgument("scale", 1.)), + spatial_mode_(OperatorBase::GetSingleArgument("spatial", 0)), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument("order", "NCHW"))) { + CAFFE_ENFORCE(scale_ >= 0); + CAFFE_ENFORCE_EQ( + order_, StorageOrder::NCHW, "Only NCHW order is supported right now."); + } + USE_OPERATOR_CONTEXT_FUNCTIONS; + + bool RunOnDevice() override; + + protected: + float scale_; + int spatial_mode_; + StorageOrder order_; + + Tensor losses_; // Per example loss + Tensor sum_multiplier_; // Vector of ones for summing via dot prod +}; + +template +class SoftmaxWithLossGradientOp final : public Operator { + public: + SoftmaxWithLossGradientOp(const OperatorDef& def, Workspace* ws) + : Operator(def, ws), + scale_(OperatorBase::GetSingleArgument("scale", 1.)), + spatial_mode_(OperatorBase::GetSingleArgument("spatial", 0)), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument("order", "NCHW"))) { + CAFFE_ENFORCE(scale_ >= 0); + CAFFE_ENFORCE_EQ( + order_, StorageOrder::NCHW, "Only NCHW order is supported right now."); + } + USE_OPERATOR_CONTEXT_FUNCTIONS; + + bool RunOnDevice() override; + + protected: + float scale_; + int spatial_mode_; + Tensor sum_multiplier_; + StorageOrder order_; +}; + +} // namespace caffe2 + +#endif // SOFTMAX_WITH_LOSS_OP_H_ diff --git a/caffe2/operators/softsign_op.cc b/caffe2/operators/softsign_op.cc index 537e92e2d55c..0a8809cf686e 100644 --- a/caffe2/operators/softsign_op.cc +++ b/caffe2/operators/softsign_op.cc @@ -14,10 +14,26 @@ struct SoftsignCPUFunctor { } }; +struct SoftsignGradientCPUFunctor { + template + inline void + Run(const int n, const T* x, const T* dy, T* dx, CPUContext* device_context) { + ConstEigenVectorArrayMap dy_arr(dy, n); + ConstEigenVectorArrayMap x_arr(x, n); + EigenVectorMap(dx, n) = dy_arr * (1 + x_arr.abs()).pow(2).inverse(); + } +}; + namespace { REGISTER_CPU_OPERATOR( Softsign, UnaryElementwiseOp, CPUContext, SoftsignCPUFunctor>); +REGISTER_CPU_OPERATOR( + SoftsignGradient, + BinaryElementwiseOp< + TensorTypes, + CPUContext, + WithoutBroadcast>); OPERATOR_SCHEMA(Softsign) .NumInputs(1) @@ -35,5 +51,39 @@ and output blobs. "The softsign (x/1+|x|) values of the input tensor " "computed element-wise"); +OPERATOR_SCHEMA(SoftsignGradient) + .NumInputs(2) + .NumOutputs(1) + .AllowInplace({{1, 0}}) + .SetDoc(R"DOC( +Calculates the softsign gradient (sgn(x)/(1+|x|)^2) of the given input tensor +element-wise. +)DOC") + .Input(0, "input", "1-D input tensor") + .Input(1, "input", "1-D input tensor") + .Output( + 0, + "output", + "The softsign gradient (sgn(x)/(1+|x|)^2) values of the input tensor " + "computed element-wise"); + +class GetSoftsignGradient : public GradientMakerBase { + using GradientMakerBase::GradientMakerBase; + vector GetGradientDefs() override { + CAFFE_ENFORCE( + I(0) != O(0), + "Cannot compute softsign gradient " + "if you choose to do an in-place calculation."); + + return SingleGradientDef( + "SoftsignGradient", + "", + vector{I(0), GO(0)}, + vector{GI(0)}); + } +}; + +REGISTER_GRADIENT(Softsign, GetSoftsignGradient); + } // namespace } // namespace caffe2 diff --git a/caffe2/operators/softsign_op.cu b/caffe2/operators/softsign_op.cu index 2258c3aef4c7..e7cd2e1d4880 100644 --- a/caffe2/operators/softsign_op.cu +++ b/caffe2/operators/softsign_op.cu @@ -12,6 +12,14 @@ __global__ void SoftsignKernel(const int N, const T* X, T* Y) { } } +template +__global__ void SoftsignGradientKernel(const int N, const T* x, const T* dy, + T* dx) { + CUDA_1D_KERNEL_LOOP(i, N) { + dx[i] = dy[i] / pow(1 + abs(x[i]), 2); + } +} + struct SoftsignCUDAFunctor { template inline void @@ -23,8 +31,18 @@ struct SoftsignCUDAFunctor { device_context->cuda_stream()>>>(n, x, y); return; } - inline bool InplaceAllowed() { - return true; +}; + +struct SoftsignGradientCUDAFunctor { + template + inline void + Run(const int n, const T* x, const T* dy, T* dx, CUDAContext* device_context) { + SoftsignGradientKernel<<< + CAFFE_GET_BLOCKS(n), + CAFFE_CUDA_NUM_THREADS, + 0, + device_context->cuda_stream()>>>(n, x, dy, dx); + return; } }; @@ -32,5 +50,8 @@ namespace { REGISTER_CUDA_OPERATOR( Softsign, UnaryElementwiseOp, CUDAContext, SoftsignCUDAFunctor>); +REGISTER_CUDA_OPERATOR( + SoftsignGradient, + BinaryElementwiseOp, CUDAContext, WithoutBroadcast>); } // namespace } // namespace caffe2 diff --git a/caffe2/operators/spatial_batch_norm_op.cc b/caffe2/operators/spatial_batch_norm_op.cc index 2378b9008ad6..679d2a09193f 100644 --- a/caffe2/operators/spatial_batch_norm_op.cc +++ b/caffe2/operators/spatial_batch_norm_op.cc @@ -75,11 +75,13 @@ bool SpatialBNOp::RunOnDevice() { // Check if they are initialized if (!running_mean->size()) { running_mean->Resize(C); - EigenVectorArrayMap(running_mean->mutable_data(), C) = 0; + EigenVectorArrayMap running_mean_map(running_mean->mutable_data(), C); + running_mean_map.setZero(); } if (!running_var->size()) { running_var->Resize(C); - EigenVectorArrayMap(running_var->mutable_data(), C) = 0; + EigenVectorArrayMap running_var_map(running_var->mutable_data(), C); + running_var_map.setZero(); } EigenVectorArrayMap running_mean_arr( running_mean->mutable_data(), C); diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index 4e3591870d44..9397cb50e86b 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -15,6 +15,8 @@ REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp); REGISTER_CPU_OPERATOR( ScatterWeightedSum, ScatterWeightedSumOp); +REGISTER_CPU_OPERATOR(Max, MaxOp); +REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp); REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp); // From whatever the current context, ensure the output is TensorCPU REGISTER_CPU_OPERATOR( @@ -74,7 +76,9 @@ When the second input is absent, an extra argument `shape` must be specified. It outputs the reshaped tensor as well as the original shape. At most one dimension of the new shape can be -1. In this case, the value is -inferred from the size of the tensor and the remaining dimensions. +inferred from the size of the tensor and the remaining dimensions. A dimension +could also be 0, in which case the actual dimension value is going to be copied +from the input tensor. )DOC") .Arg("shape", "New shape") .Input(0, "data", "An input tensor.") @@ -232,6 +236,21 @@ Currently only works on CPU because of access to INDICES. .Output(0, "X_0", "Has to be exactly the same tensor as the input 0") .EnforceInplace({{0, 0}}); +OPERATOR_SCHEMA(Max) + .NumInputs(1, INT_MAX) + .NumOutputs(1) + .AllowInplace({{0, 0}}) + .SetDoc(R"DOC( +Element-wise max of each of the input tensors. The first input tensor can be +used in-place as the output tensor, in which case the max will be done in +place and results will be accumulated in input0. All inputs and outputs must +have the same shape and data type. +)DOC") + .Input(0, "data_0", "First of the input tensors. Can be inplace.") + .Output(0, "max", "Output tensor. Same dimension as inputs."); + +OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX); + OPERATOR_SCHEMA(ScatterAssign) .NumInputs(3) .NumOutputs(1) @@ -588,6 +607,20 @@ SHOULD_NOT_DO_GRADIENT(WeightedSum); SHOULD_NOT_DO_GRADIENT(ScatterWeightedSum); SHOULD_NOT_DO_GRADIENT(ScatterAssign); +class GetMaxGradient : public GradientMakerBase { + using GradientMakerBase::GradientMakerBase; + vector GetGradientDefs() override { + auto gradInputs = vector(); + auto inputs = vector{O(0), GO(0)}; + for (int i = 0; i < def_.input_size(); i++) { + gradInputs.push_back(GI(i)); + inputs.push_back(I(i)); + } + return SingleGradientDef("MaxGradient", "", inputs, gradInputs); + } +}; +REGISTER_GRADIENT(Max, GetMaxGradient); + // TODO(jiayq): Copy is a bit tricky because one need to figure out correctly // where the input lies (e.g. for muji, which gpu). Right now I am marking it // as not gradient ready. diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index b0dedbd905f9..506bebb0f7c2 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -72,7 +72,8 @@ class PrintOp final : public Operator { bool RunOnDevice() override { if (!OperatorBase::InputIsType>(0) && !OperatorBase::InputIsType(0)) { - LOG(INFO) << "Non-tensor input."; + LOG(INFO) << "Blob of type: " + << OperatorBase::Inputs().at(0)->meta().name(); return true; } // special-case empty tensors since they may have no meta() @@ -459,6 +460,83 @@ class ScatterWeightedSumOp : public Operator { } }; +template +class MaxOp : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(MaxOp); + + bool RunOnDevice() override { + auto& input0 = Input(0); + auto* output = Output(0); + + output->ResizeLike(input0); + output->CopyFrom(input0, &context_); + + if (InputSize() == 1) { + return true; + } + + // Dimension checking + for (int i = 1; i < InputSize(); ++i) { + CAFFE_ENFORCE_EQ( + output->dims(), + Input(i).dims(), + "Description: Input #", + i, + ", input dimension:", + Input(i).dims(), + " should match output dimension: ", + output->dims()); + } + + T* output_data = output->template mutable_data(); +#pragma omp parallel for + for (int i = 1; i < InputSize(); i++) { + auto input_data = Input(i).template data(); + for (int j = 0; j < input0.size(); j++) { + output_data[j] = std::max(output_data[j], input_data[j]); + } + } + + return true; + } +}; + +template +class MaxGradientOp : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(MaxGradientOp); + + bool RunOnDevice() override { + auto& output = Input(0); + auto& grad_output = Input(1); + const int kInputStartOffset = 2; + + const T* data = output.template data(); + ConstEigenArrayMap output_array( + output.template data(), 1, output.size()); + ConstEigenArrayMap grad_out_array( + grad_output.template data(), 1, grad_output.size()); + + for (int i = 0; i < OutputSize(); i++) { + auto& input = Input(i + kInputStartOffset); + ConstEigenArrayMap input_array( + input.template data(), 1, input.size()); + + auto* grad_input = Output(i); + grad_input->ResizeLike(input); + EigenArrayMap grad_in_array( + grad_input->template mutable_data(), 1, grad_input->size()); + grad_in_array = grad_out_array * + input_array.cwiseEqual(output_array).template cast(); + } + + return true; + } +}; + /** * @brief Update slices of the tensor in-place by overriding. * @@ -744,10 +822,10 @@ class SliceOp : public Operator { auto* starts_data = starts.template data(); auto* ends_data = ends.template data(); - CHECK_EQ(starts.ndim(), 1); - CHECK_EQ(ends.ndim(), 1); - CHECK_LE(data.ndim(), starts.size()); - CHECK_EQ(starts.size(), ends.size()); + CAFFE_ENFORCE_EQ(starts.ndim(), 1); + CAFFE_ENFORCE_EQ(ends.ndim(), 1); + CAFFE_ENFORCE_GE(data.ndim(), starts.size()); + CAFFE_ENFORCE_EQ(starts.size(), ends.size()); std::vector starts_idx(data.ndim()); std::vector ends_idx(data.ndim()); @@ -767,11 +845,11 @@ class SliceOp : public Operator { if (end < 0) { end = data.dims()[i] + 1 + end; } - CHECK_GE(start, 0); - CHECK_GE(end, 0); - CHECK_LT(start, data.dims()[i]); - CHECK_LE(end, data.dims()[i]); - CHECK_GE(end, start); + CAFFE_ENFORCE_GE(start, 0); + CAFFE_ENFORCE_GE(end, 0); + CAFFE_ENFORCE_LT(start, data.dims()[i]); + CAFFE_ENFORCE_LE(end, data.dims()[i]); + CAFFE_ENFORCE_GE(end, start); starts_idx[i] = start; ends_idx[i] = end; dst_sizes[i] = end - start; @@ -780,7 +858,8 @@ class SliceOp : public Operator { int dim = -1; for (int i = 0; i < data.ndim(); ++i) { if (starts_idx[i] > 0 || ends_idx[i] < data.dims()[i]) { - CHECK_EQ(dim, -1) << "Currently only possible to slice in 1 dimension."; + CAFFE_ENFORCE_EQ( + dim, -1, "Currently only possible to slice in 1 dimension."); dim = i; } } @@ -925,6 +1004,13 @@ class ReshapeOp : public Operator { actual_new_shape.assign(shape_data, shape_data + shape.size()); } + // Copy over the dimensions for those that are specified zero. + for (int i = 0; i < actual_new_shape.size(); ++i) { + if (actual_new_shape[i] == 0) { + actual_new_shape[i] = input.dim(i); + } + } + // Checks if the new shape is valid and fills in the missing dimension // specified by -1. // NOTE: At most one dimension can be -1. diff --git a/caffe2/operators/workspace_ops.cc b/caffe2/operators/workspace_ops.cc new file mode 100644 index 000000000000..d9775aa3a775 --- /dev/null +++ b/caffe2/operators/workspace_ops.cc @@ -0,0 +1,42 @@ +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { +namespace { + +class GetAllBlobNamesOp final : public Operator { + public: + GetAllBlobNamesOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + include_shared_(GetSingleArgument("include_shared", true)), + ws_(ws) {} + + bool RunOnDevice() override { + auto* out = Output(0); + const auto& blobs = include_shared_ ? ws_->Blobs() : ws_->LocalBlobs(); + out->Resize(blobs.size()); + std::copy(blobs.begin(), blobs.end(), out->mutable_data()); + return true; + } + + private: + bool include_shared_; + Workspace* ws_; +}; + +REGISTER_CPU_OPERATOR(GetAllBlobNames, GetAllBlobNamesOp); +OPERATOR_SCHEMA(GetAllBlobNames) + .NumInputs(0) + .NumOutputs(1) + .SetDoc(R"DOC( +Return a 1D tensor of strings containing the names +of each blob in the active workspace. +)DOC") + .Arg( + "include_shared", + "(bool, default true) Whether to include blobs " + "inherited from parent workspaces.") + .Output(0, "blob_names", "1D tensor of strings containing blob names."); +SHOULD_NOT_DO_GRADIENT(GetAllBlobNamesOp); +} +} diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto index 1268d8c9235f..932f34946fb6 100644 --- a/caffe2/proto/caffe2.proto +++ b/caffe2/proto/caffe2.proto @@ -83,8 +83,9 @@ message Argument { // DeviceType that Caffe2 currently supports. enum DeviceType { - CPU = 0; // In default, we will use CPU. - CUDA = 1; // CUDA, with custom kernels. + CPU = 0; // In default, we will use CPU. + CUDA = 1; // CUDA. + ONLY_FOR_TEST = 20901701; // This device type is only for test. } // Device-specific options. We do not distinguish DeviceOption protos for @@ -93,7 +94,8 @@ enum DeviceType { // not match. message DeviceOption { // [general] Options that need to be carried out before running the execution. - optional DeviceType device_type = 1 [ default = CPU ]; + // optional DeviceType device_type = 1 [ default = CPU ]; + optional int32 device_type = 1 [ default = 0 ]; // 0 is CPU. // [CUDA specific] the cuda gpu id. optional int32 cuda_gpu_id = 2; // [general] The random seed to start the device random number generator with. @@ -224,6 +226,10 @@ message ExecutionStep { // ** It is the user's responsibility to not to put this blob in race conditions. // ** For example when setting this blob in concurrent substeps optional string should_stop_blob = 9; + + // if only_once is true, this step will only be executed once. this ONLY takes + // effect when using should_stop_blob + optional bool only_once = 10; } message PlanDef { diff --git a/caffe2/proto/hsm.proto b/caffe2/proto/hsm.proto index 534be870238b..2e3152cc332e 100644 --- a/caffe2/proto/hsm.proto +++ b/caffe2/proto/hsm.proto @@ -25,6 +25,9 @@ message NodeProto { repeated NodeProto children = 1; // Links to terminal (leaf) nodes repeated int32 word_ids = 2; + optional int32 offset = 3; + optional string name = 4; + repeated float scores = 5; } // Protobuf format to accept hierarchy for hierarchical softmax operator. diff --git a/caffe2/python/_import_c_extension.py b/caffe2/python/_import_c_extension.py index f760f12ec6d2..c4b80a889c2a 100644 --- a/caffe2/python/_import_c_extension.py +++ b/caffe2/python/_import_c_extension.py @@ -29,3 +29,15 @@ with extension_loader.DlopenGuard(): # libcaffe2_python contains a global Workspace that we need to properly delete # when exiting. Otherwise, cudart will cause segfaults sometimes. atexit.register(on_module_exit) # noqa + + +# Add functionalities for the TensorCPU interface. +def _TensorCPU_shape(self): + return tuple(self._shape) + + +def _TensorCPU_reshape(self, shape): + return self._reshape(list(shape)) + +TensorCPU.shape = property(_TensorCPU_shape) # noqa +TensorCPU.reshape = _TensorCPU_reshape # noqa diff --git a/caffe2/python/caffe_translator.py b/caffe2/python/caffe_translator.py index 10802429c7cf..9e9c4f34ddfb 100644 --- a/caffe2/python/caffe_translator.py +++ b/caffe2/python/caffe_translator.py @@ -423,3 +423,45 @@ def TranslateInstanceNorm(layer, pretrained_blobs, is_test): caffe_op.input.extend([output + '_w', output + '_b']) AddArgument(caffe_op, "order", "NCHW") return caffe_op, [weight, bias] + + +@TranslatorRegistry.Register("Eltwise") +def TranslateElementWise(layer, pretrained_blobs, is_test): + param = layer.eltwise_param + # TODO(jiayq): if we have a protobuf that uses this, lift this constraint + # and verify that we can correctly translate. + if len(param.coeff) or param.operation != 1: + raise RuntimeError("This eltwise layer is not yet supported.") + caffe_op = BaseTranslate(layer, "Sum") + return caffe_op, [] + + +@TranslatorRegistry.Register("Scale") +def TranslateScale(layer, pretrained_blobs, is_test): + caffe_op = BaseTranslate(layer, "Mul") + scale_param = layer.scale_param + AddArgument(caffe_op, "axis", scale_param.axis) + AddArgument(caffe_op, "broadcast", True) + if len(caffe_op.input) == 1: + # the scale parameter is in pretrained blobs + if scale_param.num_axes != 1: + raise RuntimeError("This path has not been verified yet.") + output = caffe_op.output[0] + caffe_op.input.append(output + '_w') + weight = utils.NumpyArrayToCaffe2Tensor( + pretrained_blobs[0].flatten(), output + '_w') + return caffe_op, [weight] + elif len(caffe_op.input) == 2: + # TODO(jiayq): find a protobuf that uses this and verify. + raise RuntimeError("This path has not been verified yet.") + else: + raise RuntimeError("Unexpected number of inputs.") + + +@TranslatorRegistry.Register("Reshape") +def TranslateReshape(layer, pretrained_blobs, is_test): + caffe_op = BaseTranslate(layer, "Reshape") + caffe_op.output.append("_" + caffe_op.input[0] + "_dims") + reshape_param = layer.reshape_param + AddArgument(caffe_op, 'shape', reshape_param.shape.dim) + return caffe_op, [] diff --git a/caffe2/python/cnn.py b/caffe2/python/cnn.py index 62a132d6e5a3..83efdf61f8f1 100644 --- a/caffe2/python/cnn.py +++ b/caffe2/python/cnn.py @@ -1,9 +1,12 @@ -from caffe2.python import core +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, scope from caffe2.python.model_helper import ModelHelperBase from caffe2.proto import caffe2_pb2 -import logging - class CNNModelHelper(ModelHelperBase): """A helper model so we can write CNN models more easily, without having to @@ -27,6 +30,24 @@ class CNNModelHelper(ModelHelperBase): "Cannot understand the CNN storage order %s." % self.order ) + def GetWeights(self, namescope=None): + if namescope is None: + namescope = scope.CurrentNameScope() + + if namescope == '': + return self.weights[:] + else: + return [w for w in self.weights if w.GetNameScope() == namescope] + + def GetBiases(self, namescope=None): + if namescope is None: + namescope = scope.CurrentNameScope() + + if namescope == '': + return self.biases[:] + else: + return [b for b in self.biases if b.GetNameScope() == namescope] + def ImageInput( self, blob_in, blob_out, **kwargs ): @@ -233,7 +254,12 @@ class CNNModelHelper(ModelHelperBase): blob_out + '_w', self.param_init_net) bias = core.ScopedBlobReference( blob_out + '_b', self.param_init_net) - self.params.extend([weight, bias]) + + if 'freeze_bias' in kwargs: + self.params.extend([weight]) + else: + self.params.extend([weight, bias]) + self.weights.append(weight) self.biases.append(bias) return op_call([blob_in, weight, bias], blob_out, **kwargs) @@ -419,6 +445,26 @@ class CNNModelHelper(ModelHelperBase): print("DepthConcat is deprecated. use Concat instead.") return self.Concat(blobs_in, blob_out, **kwargs) + def PRelu(self, blob_in, blob_out, num_channels=1, slope_init=None, + **kwargs): + """PRelu""" + slope_init = ( + slope_init if slope_init else ('ConstantFill', {'value': 0.25})) + if self.init_params: + slope = self.param_init_net.__getattr__(slope_init[0])( + [], + blob_out + '_slope', + shape=[num_channels], + **slope_init[1] + ) + else: + slope = core.ScopedBlobReference( + blob_out + '_slope', self.param_init_net) + + self.params.extend([slope]) + + return self.net.PRelu([blob_in, slope], [blob_out]) + def Relu(self, blob_in, blob_out, **kwargs): """Relu.""" if self.use_cudnn: @@ -454,7 +500,7 @@ class CNNModelHelper(ModelHelperBase): self.biases.append(bias) blob_outs = [blob_out, running_mean, running_inv_var, blob_out + "_sm", blob_out + "_siv"] - if kwargs['is_test']: + if 'is_test' in kwargs and kwargs['is_test']: blob_outputs = self.net.SpatialBN( [blob_in, scale, bias, blob_outs[1], blob_outs[2]], [blob_out], order=self.order, **kwargs) @@ -503,9 +549,13 @@ class CNNModelHelper(ModelHelperBase): wd = self.param_init_net.ConstantFill([], 'wd', shape=[1], value=weight_decay) ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0) - for param in self.weights: + for param in self.GetWeights(): # Equivalent to: grad += wd * param - self.net.WeightedSum([self.param_to_grad[param], ONE, param, wd]) + grad = self.param_to_grad[param] + self.net.WeightedSum( + [grad, ONE, param, wd], + grad, + ) @property def CPU(self): diff --git a/caffe2/python/context.py b/caffe2/python/context.py new file mode 100644 index 000000000000..eaf68c2609cb --- /dev/null +++ b/caffe2/python/context.py @@ -0,0 +1,101 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import threading + +_CONTEXT_MANAGER = threading.local() + + +def context_manager(): + global _CONTEXT_MANAGER + if not hasattr(_CONTEXT_MANAGER, 'obj'): + _CONTEXT_MANAGER.obj = ContextManager() + return _CONTEXT_MANAGER.obj + + +class ContextInfo(object): + def __init__(self, cls, allow_default, arg_name): + self.cls = cls + self.allow_default = allow_default + self.arg_name = arg_name + self._stack = [] + + def enter(self, value): + self._stack.append(value) + + def exit(self, value): + assert len(self._stack) > 0, 'Context %s is empty.' % self.cls + assert self._stack.pop() == value + + def get_active(self, required=True): + if len(self._stack) == 0: + if not required: + return None + assert self.allow_default, ( + 'Context %s is required but none is active.' % self.cls) + self.enter(self.cls()) + return self._stack[-1] + + +class ContextManager(object): + def __init__(self): + self._ctxs = {} + + def register(self, ctx_info): + assert isinstance(ctx_info, ContextInfo) + assert (ctx_info.cls not in self._ctxs), ( + 'Context %s already registered' % ctx_info.cls) + self._ctxs[ctx_info.cls] = ctx_info + + def get(self, cls): + assert cls in self._ctxs, 'Context %s not registered.' % cls + return self._ctxs[cls] + + +def __enter__(self): + if self._prev_enter is not None: + self._prev_enter() + context_manager().get(self._ctx_class).enter(self) + return self + + +def __exit__(self, *args): + context_manager().get(self._ctx_class).exit(self) + if self._prev_exit is not None: + self._prev_exit(*args) + + +@classmethod +def current(cls, value=None, required=True): + return get_active_context(cls, value, required) + + +class define_context(object): + def __init__(self, arg_name=None, allow_default=False): + self.arg_name = arg_name + self.allow_default = allow_default + + def __call__(self, cls): + assert not hasattr(cls, '_ctx_class'), ( + '%s parent class (%s) already defines context.' % ( + cls, cls._ctx_class)) + context_manager().register( + ContextInfo(cls, self.allow_default, self.arg_name)) + cls._prev_enter = cls.__enter__ if hasattr(cls, '__enter__') else None + cls._prev_exit = cls.__exit__ if hasattr(cls, '__exit__') else None + cls._ctx_class = cls + cls.__enter__ = __enter__ + cls.__exit__ = __exit__ + cls.current = current + return cls + + +def get_active_context(cls, val=None, required=True): + ctx_info = context_manager().get(cls) + if val is not None: + assert isinstance(val, cls), ( + 'Wrong context type. Expected: %s, got %s.' % (cls, type(val))) + return val + return ctx_info.get_active(required=required) diff --git a/caffe2/python/control.py b/caffe2/python/control.py index 9514c3a2841c..de629d21f595 100644 --- a/caffe2/python/control.py +++ b/caffe2/python/control.py @@ -17,6 +17,67 @@ from __future__ import unicode_literals from caffe2.python import core +# Used to generate names of the steps created by the control functions. +# It is actually the internal index of these steps. +_current_idx = 1 +_used_step_names = set() + + +def _get_next_step_name(control_name, base_name): + global _current_idx, _used_step_names + concat_name = '%s/%s' % (base_name, control_name) + next_name = concat_name + while next_name in _used_step_names: + next_name = '%s_%d' % (concat_name, _current_idx) + _current_idx += 1 + _used_step_names.add(next_name) + return next_name + + +def _MakeList(input): + """ input is a tuple. + Example: + (a, b, c) --> [a, b, c] + (a) --> [a] + ([a, b, c]) --> [a, b, c] + """ + if len(input) == 0: + raise ValueError( + 'input cannot be empty.') + elif len(input) == 1: + output = input[0] + if not isinstance(output, list): + output = [output] + else: + output = list(input) + return output + + +def _IsNets(nets_or_steps): + if isinstance(nets_or_steps, list): + return all(isinstance(n, core.Net) for n in nets_or_steps) + else: + return isinstance(nets_or_steps, core.Net) + + +def _PrependNets(nets_or_steps, *nets): + nets_or_steps = _MakeList((nets_or_steps,)) + nets = _MakeList(nets) + if _IsNets(nets_or_steps): + return nets + nets_or_steps + else: + return [Do('prepend', nets)] + nets_or_steps + + +def _AppendNets(nets_or_steps, *nets): + nets_or_steps = _MakeList((nets_or_steps,)) + nets = _MakeList(nets) + if _IsNets(nets_or_steps): + return nets_or_steps + nets + else: + return nets_or_steps + [Do('append', nets)] + + def GetConditionBlobFromNet(condition_net): """ The condition blob is the last external_output that must @@ -30,6 +91,39 @@ def GetConditionBlobFromNet(condition_net): # when we create new ops (such as OR of two inputs) return core.BlobReference(condition_net.Proto().external_output[-1]) + +def BoolNet(*blobs_with_bool_value): + """A net assigning constant bool values to blobs. It is mainly used for + initializing condition blobs, for example, in multi-task learning, we + need to access reader_done blobs before reader_net run. In that case, + the reader_done blobs must be initialized. + + Args: + blobs_with_bool_value: one or more (blob, bool_value) pairs. The net will + assign each bool_value to the corresponding blob. + + returns + bool_net: A net assigning constant bool values to blobs. + + Examples: + - BoolNet((blob_1, bool_value_1), ..., (blob_n, bool_value_n)) + - BoolNet([(blob_1, net1), ..., (blob_n, bool_value_n)]) + - BoolNet((cond_1, bool_value_1)) + """ + blobs_with_bool_value = _MakeList(blobs_with_bool_value) + bool_net = core.Net('bool_net') + for blob, bool_value in blobs_with_bool_value: + out_blob = bool_net.ConstantFill( + [], + [blob], + shape=[], + value=bool_value, + dtype=core.DataType.BOOL) + bool_net.AddExternalOutput(out_blob) + + return bool_net + + def NotNet(condition_blob_or_net): """Not of a condition blob or net @@ -109,114 +203,149 @@ def MergeConditionNets(name, condition_nets, relation): return merged_net -def Do(*nets_or_steps): +def CombineConditions(name, condition_nets, relation): + """ + Combine conditions of multi nets into a single condition nets. Unlike + MergeConditionNets, the actual body of condition_nets is not copied into + the combine condition net. + + One example is about multi readers. Each reader net has a reader_done + condition. When we want to check whether all readers are done, we can + use this function to build a new net. + + Args: + name: name of the new condition net. + condition_nets: a list of condition nets. The last external_output + of each condition net must be single bool value. + relation: can be 'And' or 'Or'. + + Returns: + - A new condition net. Its last external output is relation of all + condition_nets. + """ + if not condition_nets: + return None + if not isinstance(condition_nets, list): + raise ValueError('condition_nets must be a list of nets.') + + if len(condition_nets) == 1: + condition_blob = GetConditionBlobFromNet(condition_nets[0]) + condition_net, _ = _CopyConditionBlobNet(condition_blob) + return condition_net + + combined_net = core.Net(name) + for i in range(len(condition_nets)): + curr_cond = GetConditionBlobFromNet(condition_nets[i]) + if i == 0: + last_cond = curr_cond + else: + last_cond = combined_net.__getattr__(relation)( + [last_cond, curr_cond]) + + combined_net.AddExternalOutput(last_cond) + + return combined_net + + +def Do(name, *nets_or_steps): """ Execute the sequence of nets or steps once. Examples: - - Do(net1, net2, ..., net_n) - - Do(list_of_nets) - - Do(step1, step2, ..., step_n) - - Do(list_of_steps) + - Do('myDo', net1, net2, ..., net_n) + - Do('myDo', list_of_nets) + - Do('myDo', step1, step2, ..., step_n) + - Do('myDo', list_of_steps) """ - if len(nets_or_steps) == 0: - raise ValueError( - 'nets_or_steps cannot be empty.') - elif len(nets_or_steps) == 1: - nets_or_steps = nets_or_steps[0] + nets_or_steps = _MakeList(nets_or_steps) + if (len(nets_or_steps) == 1 and isinstance( + nets_or_steps[0], core.ExecutionStep)): + return nets_or_steps[0] else: - nets_or_steps = list(nets_or_steps) - - return core.execution_step('Do', nets_or_steps) + return core.execution_step( + _get_next_step_name('Do', name), nets_or_steps) -def DoParallel(*nets_or_steps): +def DoParallel(name, *nets_or_steps): """ Execute the nets or steps in parallel, waiting for all of them to finish Examples: - - DoParallel(net1, net2, ..., net_n) - - DoParallel(list_of_nets) - - DoParallel(step1, step2, ..., step_n) - - DoParallel(list_of_steps) + - DoParallel('pDo', net1, net2, ..., net_n) + - DoParallel('pDo', list_of_nets) + - DoParallel('pDo', step1, step2, ..., step_n) + - DoParallel('pDo', list_of_steps) """ - if len(nets_or_steps) == 0: - raise ValueError( - 'nets_or_steps cannot be empty.') - elif len(nets_or_steps) == 1: - nets_or_steps = nets_or_steps[0] + nets_or_steps = _MakeList(nets_or_steps) + if (len(nets_or_steps) == 1 and isinstance( + nets_or_steps[0], core.ExecutionStep)): + return nets_or_steps[0] else: - nets_or_steps = list(nets_or_steps) - - return core.execution_step( - 'DoParallel', nets_or_steps, concurrent_substeps=True) + return core.execution_step( + _get_next_step_name('DoParallel', name), + nets_or_steps, + concurrent_substeps=True) -def _StopNet(stop_blob): - stop_net = core.Net('stop_net') - stop_net.ConstantFill( - [], [stop_blob], shape=[], value=True, dtype=core.DataType.BOOL) - return stop_net - - -def _ToExecutionStep(net_or_step): - if isinstance(net_or_step, core.Net): - return Do(net_or_step) - elif isinstance(net_or_step, core.ExecutionStep): - return net_or_step - else: - raise ValueError( - 'net_or_step must be a net or a step.') - - -def _RunOnceIf(condition_blob_or_net, net_or_step): +def _RunOnceIf(name, condition_blob_or_net, nets_or_steps): """ - Execute net_or_step once if condition_blob_or_net evaluates as true. + Execute nets_or_steps once if condition_blob_or_net evaluates as true. If condition_blob_or_net is Net, the condition is its last external_output - that must be a single bool. And this net will be executed before net_or_step - so as to get the condition. + that must be a single bool. And this net will be executed before + nets_or_steps so as to get the condition. """ + condition_not_net, stop_blob = NotNet(condition_blob_or_net) if isinstance(condition_blob_or_net, core.Net): - condition_blob = GetConditionBlobFromNet(condition_blob_or_net) - return Do(Do(condition_blob_or_net), - _RunOnceIf(condition_blob, net_or_step)) + nets_or_steps = _PrependNets( + nets_or_steps, condition_blob_or_net, condition_not_net) + else: + nets_or_steps = _PrependNets(nets_or_steps, condition_not_net) - stop_if_not_net, stop_blob = NotNet(condition_blob_or_net) - stop_net = _StopNet(stop_blob) + def if_step(control_name): + return core.execution_step( + _get_next_step_name(control_name, name), + nets_or_steps, + should_stop_blob=stop_blob, + only_once=True, + ) - return core.execution_step( - '_RunOnceIf', - [Do(stop_if_not_net), _ToExecutionStep(net_or_step), Do(stop_net)], - should_stop_blob=stop_blob) + if _IsNets(nets_or_steps): + bool_net = BoolNet((stop_blob, False)) + return Do(name + '/_RunOnceIf', + bool_net, if_step('_RunOnceIf-inner')) + else: + return if_step('_RunOnceIf') -def _RunOnceIfNot(condition_blob_or_net, net_or_step): +def _RunOnceIfNot(name, condition_blob_or_net, nets_or_steps): """ - Similar to _RunOnceIf() but Execute net_or_step once if + Similar to _RunOnceIf() but Execute nets_or_steps once if condition_blob_or_net evaluates as false. """ if isinstance(condition_blob_or_net, core.Net): condition_blob = GetConditionBlobFromNet(condition_blob_or_net) - return Do(Do(condition_blob_or_net), - _RunOnceIfNot(condition_blob, net_or_step)) - - stop_if_net, stop_blob = _CopyConditionBlobNet(condition_blob_or_net) - stop_net = _StopNet(stop_blob) + nets_or_steps = _PrependNets(nets_or_steps, condition_blob_or_net) + else: + copy_net, condition_blob = _CopyConditionBlobNet(condition_blob_or_net) + nets_or_steps = _PrependNets(nets_or_steps, copy_net) return core.execution_step( - '_RunOnceIfNot', - [Do(stop_if_net), _ToExecutionStep(net_or_step), Do(stop_net)], - should_stop_blob=stop_blob) + _get_next_step_name('_RunOnceIfNot', name), + nets_or_steps, + should_stop_blob=condition_blob, + only_once=True, + ) -def For(net_or_step, iter_num): +def For(name, nets_or_steps, iter_num): """ - Execute net_or_step iter_num times. + Execute nets_or_steps iter_num times. Args: - net_or_step: an instance of a ExecutionStep or a Net. - iter_num: the number times to execute the net_or_step. + nets_or_steps: a ExecutionStep or a Net or a list of ExecutionSteps or + a list nets. + iter_num: the number times to execute the nets_or_steps. Returns: A ExecutionStep instance. @@ -226,175 +355,215 @@ def For(net_or_step, iter_num): iter_net = core.Net('For-iter') iter_done = iter_net.CountDown([iter_cnt]) - if isinstance(net_or_step, core.Net): - for_step = core.execution_step( - 'For', [iter_net, net_or_step], should_stop_blob=iter_done) - elif isinstance(net_or_step, core.ExecutionStep): - for_step = core.execution_step( - 'For', [Do(iter_net), net_or_step], should_stop_blob=iter_done) - else: - raise ValueError( - 'net_or_step must be a net or a step.') - - return Do(Do(init_net), for_step) + for_step = core.execution_step( + _get_next_step_name('For-inner', name), + _PrependNets(nets_or_steps, iter_net), + should_stop_blob=iter_done) + return Do(name + '/For', + Do(name + '/For-init-net', init_net), + for_step) -def While(condition_blob_or_net, net_or_step): +def While(name, condition_blob_or_net, nets_or_steps): """ - Execute net_or_step when condition_blob_or_net returns true. + Execute nets_or_steps when condition_blob_or_net returns true. Args: condition_blob_or_net: If it is an instance of Net, its last external_output must be a single bool. - net_or_step: an instance of a ExecutionStep or a Net. + nets_or_steps: a ExecutionStep or a Net or a list of ExecutionSteps or + a list nets. Returns: A ExecutionStep instance. """ condition_not_net, stop_blob = NotNet(condition_blob_or_net) if isinstance(condition_blob_or_net, core.Net): - condition_step = Do(condition_blob_or_net, condition_not_net) + nets_or_steps = _PrependNets( + nets_or_steps, condition_blob_or_net, condition_not_net) else: - condition_step = Do(condition_not_net) + nets_or_steps = _PrependNets(nets_or_steps, condition_not_net) - return core.execution_step( - 'While', - [condition_step, _ToExecutionStep(net_or_step)], - should_stop_blob=stop_blob) + def while_step(control_name): + return core.execution_step( + _get_next_step_name(control_name, name), + nets_or_steps, + should_stop_blob=stop_blob, + ) + + if _IsNets(nets_or_steps): + # In this case, while_step has sub-nets: + # [condition_blob_or_net, condition_not_net, nets_or_steps] + # If stop_blob is pre-set to True (this may happen when While() is + # called twice), the loop will exit after executing + # condition_blob_or_net. So we use BootNet to set stop_blob to + # False. + bool_net = BoolNet((stop_blob, False)) + return Do(name + '/While', bool_net, while_step('While-inner')) + else: + return while_step('While') -def Until(condition_blob_or_net, net_or_step): +def Until(name, condition_blob_or_net, nets_or_steps): """ - Similar to While() but execute net_or_step when + Similar to While() but execute nets_or_steps when condition_blob_or_net returns false """ if isinstance(condition_blob_or_net, core.Net): stop_blob = GetConditionBlobFromNet(condition_blob_or_net) - condition_step = Do(condition_blob_or_net) + nets_or_steps = _PrependNets(nets_or_steps, condition_blob_or_net) else: - copy_net, stop_blob = _CopyConditionBlobNet(condition_blob_or_net) - condition_step = Do(copy_net) + stop_blob = core.BlobReference(str(condition_blob_or_net)) return core.execution_step( - 'Until', - [condition_step, _ToExecutionStep(net_or_step)], + _get_next_step_name('Until', name), + nets_or_steps, should_stop_blob=stop_blob) -def DoWhile(condition_blob_or_net, net_or_step): +def DoWhile(name, condition_blob_or_net, nets_or_steps): """ - Execute net_or_step when condition_blob_or_net returns true. It will execute - net_or_step at least once. + Execute nets_or_steps when condition_blob_or_net returns true. It will + execute nets_or_steps before evaluating condition_blob_or_net. Args: condition_blob_or_net: if it is an instance of Net, tts last external_output must be a single bool. - net_or_step: an instance of a ExecutionStep or a Net. + nets_or_steps: a ExecutionStep or a Net or a list of ExecutionSteps or + a list nets. Returns: A ExecutionStep instance. """ condition_not_net, stop_blob = NotNet(condition_blob_or_net) if isinstance(condition_blob_or_net, core.Net): - condition_step = Do(condition_blob_or_net, condition_not_net) + nets_or_steps = _AppendNets( + nets_or_steps, condition_blob_or_net, condition_not_net) else: - condition_step = Do(condition_not_net) + nets_or_steps = _AppendNets(nets_or_steps, condition_not_net) - return core.execution_step( - 'DoWhile', - [_ToExecutionStep(net_or_step), condition_step], - should_stop_blob=stop_blob) + # If stop_blob is pre-set to True (this may happen when DoWhile() is + # called twice), the loop will exit after executing the first net/step + # in nets_or_steps. This is not what we want. So we use BootNet to + # set stop_blob to False. + bool_net = BoolNet((stop_blob, False)) + return Do(name + '/DoWhile', bool_net, core.execution_step( + _get_next_step_name('DoWhile-inner', name), + nets_or_steps, + should_stop_blob=stop_blob, + )) -def DoUntil(condition_blob_or_net, net_or_step): +def DoUntil(name, condition_blob_or_net, nets_or_steps): """ - Similar to DoWhile() but execute net_or_step when - condition_blob_or_net returns false + Similar to DoWhile() but execute nets_or_steps when + condition_blob_or_net returns false. It will execute + nets_or_steps before evaluating condition_blob_or_net. + + Special case: if condition_blob_or_net is a blob and is pre-set to + true, then only the first net/step of nets_or_steps will be executed and + loop is exited. So you need to be careful about the initial value the + condition blob when using DoUntil(), esp when DoUntil() is called twice. """ - steps = [_ToExecutionStep(net_or_step)] + if not isinstance(condition_blob_or_net, core.Net): + stop_blob = core.BlobReference(condition_blob_or_net) + return core.execution_step( + _get_next_step_name('DoUntil', name), + nets_or_steps, + should_stop_blob=stop_blob) - if isinstance(condition_blob_or_net, core.Net): - steps.append(Do(condition_blob_or_net)) - stop_blob = GetConditionBlobFromNet(condition_blob_or_net) - else: - stop_blob = condition_blob_or_net + nets_or_steps = _AppendNets(nets_or_steps, condition_blob_or_net) + stop_blob = GetConditionBlobFromNet(condition_blob_or_net) - stop_blob = core.BlobReference(str(stop_blob)) - return core.execution_step('DoUntil', steps, should_stop_blob=stop_blob) + # If stop_blob is pre-set to True (this may happen when DoWhile() is + # called twice), the loop will exit after executing the first net/step + # in nets_or_steps. This is not what we want. So we use BootNet to + # set stop_blob to False. + bool_net = BoolNet((stop_blob, False)) + return Do(name + '/DoUntil', bool_net, core.execution_step( + _get_next_step_name('DoUntil-inner', name), + nets_or_steps, + should_stop_blob=stop_blob, + )) -def Switch(*conditions): +def Switch(name, *conditions): """ Execute the steps for which the condition is true. - Each condition is a tuple (condition_blob_or_net, step). + Each condition is a tuple (condition_blob_or_net, nets_or_steps). Note: 1. Multi steps can be executed if their conditions are true. 2. The conditions_blob_or_net (if it is Net) of all steps will be executed once. Examples: - - Switch((cond_1, net_1), (cond_2, net_2), ..., (cond_n, net_n)) - - Switch([(cond_1, net1), (cond_2, net_2), ..., (cond_n, net_n)]) - - Switch((cond_1, net_1)) + - Switch('name', (cond_1, net_1), (cond_2, net_2), ..., (cond_n, net_n)) + - Switch('name', [(cond_1, net1), (cond_2, net_2), ..., (cond_n, net_n)]) + - Switch('name', (cond_1, net_1)) """ - if len(conditions) == 0: - raise ValueError( - 'conditions cannot be empty.') - elif len(conditions) == 1: - conditions = conditions[0] - if not isinstance(conditions, list): - conditions = [conditions] - else: - conditions = list(conditions) - + conditions = _MakeList(conditions) return core.execution_step( - 'Switch', [_RunOnceIf(cond, step) for cond, step in conditions]) + _get_next_step_name('Switch', name), + [_RunOnceIf(name + '/Switch', cond, step) for cond, step in conditions]) -def If(condition_blob_or_net, true_net_or_step, false_net_or_step=None): +def SwitchNot(name, *conditions): + """ + Similar to Switch() but execute the steps for which the condition is False. + """ + conditions = _MakeList(conditions) + return core.execution_step( + _get_next_step_name('SwitchNot', name), + [_RunOnceIfNot(name + '/SwitchNot', cond, step) + for cond, step in conditions]) + + +def If(name, condition_blob_or_net, + true_nets_or_steps, false_nets_or_steps=None): """ condition_blob_or_net is first evaluated or executed. If the condition is - true, true_net_or_step is then executed, otherwise, false_net_or_step + true, true_nets_or_steps is then executed, otherwise, false_nets_or_steps is executed. If condition_blob_or_net is Net, the condition is its last external_output that must be a single bool. And this Net will be executred before both - true/false_net_or_step so as to get the condition. + true/false_nets_or_steps so as to get the condition. """ - if not false_net_or_step: - return _RunOnceIf(condition_blob_or_net, true_net_or_step) + if not false_nets_or_steps: + return _RunOnceIf(name + '/If', + condition_blob_or_net, true_nets_or_steps) if isinstance(condition_blob_or_net, core.Net): condition_blob = GetConditionBlobFromNet(condition_blob_or_net) - return Do(Do(condition_blob_or_net), - If(condition_blob, true_net_or_step, false_net_or_step)) + else: + condition_blob = condition_blob_or_net - condition_blob = condition_blob_or_net - not_net, _ = NotNet(condition_blob) - - return Switch( - (condition_blob, true_net_or_step), - (not_net, false_net_or_step), + return Do( + name + '/If', + _RunOnceIf(name + '/If-true', + condition_blob_or_net, true_nets_or_steps), + _RunOnceIfNot(name + '/If-false', condition_blob, false_nets_or_steps) ) -def IfNot(condition_blob_or_net, true_net_or_step, false_net_or_step=None): +def IfNot(name, condition_blob_or_net, + true_nets_or_steps, false_nets_or_steps=None): """ - If condition_blob_or_net returns false, executes true_net_or_step, - otherwise executes false_net_or_step + If condition_blob_or_net returns false, executes true_nets_or_steps, + otherwise executes false_nets_or_steps """ - if not false_net_or_step: - return _RunOnceIfNot(condition_blob_or_net, true_net_or_step) + if not false_nets_or_steps: + return _RunOnceIfNot(name + '/IfNot', + condition_blob_or_net, true_nets_or_steps) if isinstance(condition_blob_or_net, core.Net): condition_blob = GetConditionBlobFromNet(condition_blob_or_net) - return Do(Do(condition_blob_or_net), - IfNot(condition_blob, true_net_or_step, false_net_or_step)) + else: + condition_blob = condition_blob_or_net - condition_blob = condition_blob_or_net - not_net, _ = NotNet(condition_blob) - - return Switch( - (condition_blob, false_net_or_step), - (not_net, true_net_or_step), + return Do( + name + '/IfNot', + _RunOnceIfNot(name + '/IfNot-true', + condition_blob_or_net, true_nets_or_steps), + _RunOnceIf(name + '/IfNot-false', condition_blob, false_nets_or_steps) ) diff --git a/caffe2/python/control_test.py b/caffe2/python/control_test.py index 066f7a6e32c3..e51aeffa8b04 100644 --- a/caffe2/python/control_test.py +++ b/caffe2/python/control_test.py @@ -28,6 +28,14 @@ class TestControl(test_util.TestCase): [], [curr_cnt], shape=[], value=0, dtype=core.DataType.INT64) self.cnt_net_.AddExternalOutput(curr_cnt) + self.cnt_2_net_ = core.Net("cnt-2-net") + self.cnt_2_net_.CountUp([cnt]) + self.cnt_2_net_.CountUp([cnt]) + curr_cnt_2 = self.cnt_2_net_.RetrieveCount([cnt]) + self.init_net_.ConstantFill( + [], [curr_cnt_2], shape=[], value=0, dtype=core.DataType.INT64) + self.cnt_2_net_.AddExternalOutput(curr_cnt_2) + self.cond_net_ = core.Net("cond-net") cond_blob = self.cond_net_.LT([curr_cnt, const_n]) self.cond_net_.AddExternalOutput(cond_blob) @@ -44,6 +52,10 @@ class TestControl(test_util.TestCase): false_blob = self.false_cond_net_.GT([const_0, const_n]) self.false_cond_net_.AddExternalOutput(false_blob) + self.idle_net_ = core.Net("idle-net") + self.idle_net_.ConstantFill( + [], shape=[], value=0, dtype=core.DataType.INT64) + def CheckNetOutput(self, nets_and_expects): """ Check the net output is expected @@ -54,80 +66,102 @@ class TestControl(test_util.TestCase): net.Proto().external_output[-1]) self.assertEqual(output, expect) + def CheckNetAllOutput(self, net, expects): + """ + Check the net output is expected + expects is a list of bools. + """ + self.assertEqual(len(net.Proto().external_output), len(expects)) + for i in range(len(expects)): + output = workspace.FetchBlob( + net.Proto().external_output[i]) + self.assertEqual(output, expects[i]) + def BuildAndRunPlan(self, step): plan = core.Plan("test") - plan.AddStep(control.Do(self.init_net_)) + plan.AddStep(control.Do('init', self.init_net_)) plan.AddStep(step) self.assertEqual(workspace.RunPlan(plan), True) - def ForLoopTest(self, net_or_step): - step = control.For(net_or_step, self.N_) + def ForLoopTest(self, nets_or_steps): + step = control.For('myFor', nets_or_steps, self.N_) self.BuildAndRunPlan(step) self.CheckNetOutput([(self.cnt_net_, self.N_)]) - def testForLoopWithNet(self): + def testForLoopWithNets(self): self.ForLoopTest(self.cnt_net_) + self.ForLoopTest([self.cnt_net_, self.idle_net_]) def testForLoopWithStep(self): - step = control.Do(self.cnt_net_) + step = control.Do('count', self.cnt_net_) self.ForLoopTest(step) + self.ForLoopTest([step, self.idle_net_]) - def WhileLoopTest(self, net_or_step): - step = control.While(self.cond_net_, net_or_step) + def WhileLoopTest(self, nets_or_steps): + step = control.While('myWhile', self.cond_net_, nets_or_steps) self.BuildAndRunPlan(step) self.CheckNetOutput([(self.cnt_net_, self.N_)]) def testWhileLoopWithNet(self): self.WhileLoopTest(self.cnt_net_) + self.WhileLoopTest([self.cnt_net_, self.idle_net_]) def testWhileLoopWithStep(self): - step = control.Do(self.cnt_net_) + step = control.Do('count', self.cnt_net_) self.WhileLoopTest(step) + self.WhileLoopTest([step, self.idle_net_]) - def UntilLoopTest(self, net_or_step): - step = control.Until(self.not_cond_net_, net_or_step) + def UntilLoopTest(self, nets_or_steps): + step = control.Until('myUntil', self.not_cond_net_, nets_or_steps) self.BuildAndRunPlan(step) self.CheckNetOutput([(self.cnt_net_, self.N_)]) def testUntilLoopWithNet(self): self.UntilLoopTest(self.cnt_net_) + self.UntilLoopTest([self.cnt_net_, self.idle_net_]) def testUntilLoopWithStep(self): - step = control.Do(self.cnt_net_) + step = control.Do('count', self.cnt_net_) self.UntilLoopTest(step) + self.UntilLoopTest([step, self.idle_net_]) - def DoWhileLoopTest(self, net_or_step): - step = control.DoWhile(self.cond_net_, net_or_step) + def DoWhileLoopTest(self, nets_or_steps): + step = control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps) self.BuildAndRunPlan(step) self.CheckNetOutput([(self.cnt_net_, self.N_)]) def testDoWhileLoopWithNet(self): self.DoWhileLoopTest(self.cnt_net_) + self.DoWhileLoopTest([self.idle_net_, self.cnt_net_]) def testDoWhileLoopWithStep(self): - step = control.Do(self.cnt_net_) + step = control.Do('count', self.cnt_net_) self.DoWhileLoopTest(step) + self.DoWhileLoopTest([self.idle_net_, step]) - def DoUntilLoopTest(self, net_or_step): - step = control.DoUntil(self.not_cond_net_, net_or_step) + def DoUntilLoopTest(self, nets_or_steps): + step = control.DoUntil('myDoUntil', self.not_cond_net_, nets_or_steps) self.BuildAndRunPlan(step) self.CheckNetOutput([(self.cnt_net_, self.N_)]) def testDoUntilLoopWithNet(self): self.DoUntilLoopTest(self.cnt_net_) + self.DoUntilLoopTest([self.cnt_net_, self.idle_net_]) def testDoUntilLoopWithStep(self): - step = control.Do(self.cnt_net_) + step = control.Do('count', self.cnt_net_) self.DoUntilLoopTest(step) + self.DoUntilLoopTest([self.idle_net_, step]) def IfCondTest(self, cond_net, expect, cond_on_blob): if cond_on_blob: step = control.Do( - control.Do(cond_net), - control.If(cond_net.Proto().external_output[-1], + 'if-all', + control.Do('count', cond_net), + control.If('myIf', cond_net.Proto().external_output[-1], self.cnt_net_)) else: - step = control.If(cond_net, self.cnt_net_) + step = control.If('myIf', cond_net, self.cnt_net_) self.BuildAndRunPlan(step) self.CheckNetOutput([(self.cnt_net_, expect)]) @@ -143,39 +177,44 @@ class TestControl(test_util.TestCase): def testIfCondFalseOnBlob(self): self.IfCondTest(self.false_cond_net_, 0, True) - def IfElseCondTest(self, cond_net, expect, cond_on_blob): - true_step = control.For(self.cnt_net_, self.N_) - false_step = control.For(self.cnt_net_, 2 * self.N_) + def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob): + if cond_value: + run_net = self.cnt_net_ + else: + run_net = self.cnt_2_net_ if cond_on_blob: step = control.Do( - control.Do(cond_net), - control.If(cond_net.Proto().external_output[-1], - true_step, false_step)) + 'if-else-all', + control.Do('count', cond_net), + control.If('myIfElse', cond_net.Proto().external_output[-1], + self.cnt_net_, self.cnt_2_net_)) else: - step = control.If(cond_net, true_step, false_step) + step = control.If('myIfElse', cond_net, + self.cnt_net_, self.cnt_2_net_) self.BuildAndRunPlan(step) - self.CheckNetOutput([(self.cnt_net_, expect)]) + self.CheckNetOutput([(run_net, expect)]) def testIfElseCondTrueOnNet(self): - self.IfElseCondTest(self.true_cond_net_, self.N_, False) + self.IfElseCondTest(self.true_cond_net_, True, 1, False) def testIfElseCondTrueOnBlob(self): - self.IfElseCondTest(self.true_cond_net_, self.N_, True) + self.IfElseCondTest(self.true_cond_net_, True, 1, True) def testIfElseCondFalseOnNet(self): - self.IfElseCondTest(self.false_cond_net_, 2 * self.N_, False) + self.IfElseCondTest(self.false_cond_net_, False, 2, False) def testIfElseCondFalseOnBlob(self): - self.IfElseCondTest(self.false_cond_net_, 2 * self.N_, True) + self.IfElseCondTest(self.false_cond_net_, False, 2, True) def IfNotCondTest(self, cond_net, expect, cond_on_blob): if cond_on_blob: step = control.Do( - control.Do(cond_net), - control.IfNot(cond_net.Proto().external_output[-1], + 'if-not', + control.Do('count', cond_net), + control.IfNot('myIfNot', cond_net.Proto().external_output[-1], self.cnt_net_)) else: - step = control.IfNot(cond_net, self.cnt_net_) + step = control.IfNot('myIfNot', cond_net, self.cnt_net_) self.BuildAndRunPlan(step) self.CheckNetOutput([(self.cnt_net_, expect)]) @@ -191,27 +230,102 @@ class TestControl(test_util.TestCase): def testIfNotCondFalseOnBlob(self): self.IfNotCondTest(self.false_cond_net_, 1, True) - def IfNotElseCondTest(self, cond_net, expect, cond_on_blob): - true_step = control.For(self.cnt_net_, self.N_) - false_step = control.For(self.cnt_net_, 2 * self.N_) + def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob): + if cond_value: + run_net = self.cnt_2_net_ + else: + run_net = self.cnt_net_ if cond_on_blob: step = control.Do( - control.Do(cond_net), - control.IfNot(cond_net.Proto().external_output[-1], - true_step, false_step)) + 'if-not-else', + control.Do('count', cond_net), + control.IfNot('myIfNotElse', + cond_net.Proto().external_output[-1], + self.cnt_net_, self.cnt_2_net_)) else: - step = control.IfNot(cond_net, true_step, false_step) + step = control.IfNot('myIfNotElse', cond_net, + self.cnt_net_, self.cnt_2_net_) self.BuildAndRunPlan(step) - self.CheckNetOutput([(self.cnt_net_, expect)]) + self.CheckNetOutput([(run_net, expect)]) def testIfNotElseCondTrueOnNet(self): - self.IfNotElseCondTest(self.true_cond_net_, 2 * self.N_, False) + self.IfNotElseCondTest(self.true_cond_net_, True, 2, False) def testIfNotElseCondTrueOnBlob(self): - self.IfNotElseCondTest(self.true_cond_net_, 2 * self.N_, True) + self.IfNotElseCondTest(self.true_cond_net_, True, 2, True) def testIfNotElseCondFalseOnNet(self): - self.IfNotElseCondTest(self.false_cond_net_, self.N_, False) + self.IfNotElseCondTest(self.false_cond_net_, False, 1, False) def testIfNotElseCondFalseOnBlob(self): - self.IfNotElseCondTest(self.false_cond_net_, self.N_, True) + self.IfNotElseCondTest(self.false_cond_net_, False, 1, True) + + def testSwitch(self): + step = control.Switch( + 'mySwitch', + (self.false_cond_net_, self.cnt_net_), + (self.true_cond_net_, self.cnt_2_net_) + ) + self.BuildAndRunPlan(step) + self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)]) + + def testSwitchNot(self): + step = control.SwitchNot( + 'mySwitchNot', + (self.false_cond_net_, self.cnt_net_), + (self.true_cond_net_, self.cnt_2_net_) + ) + self.BuildAndRunPlan(step) + self.CheckNetOutput([(self.cnt_net_, 1), (self.cnt_2_net_, 0)]) + + def testBoolNet(self): + bool_net = control.BoolNet(('a', True)) + step = control.Do('bool', bool_net) + self.BuildAndRunPlan(step) + self.CheckNetAllOutput(bool_net, [True]) + + bool_net = control.BoolNet(('a', True), ('b', False)) + step = control.Do('bool', bool_net) + self.BuildAndRunPlan(step) + self.CheckNetAllOutput(bool_net, [True, False]) + + bool_net = control.BoolNet([('a', True), ('b', False)]) + step = control.Do('bool', bool_net) + self.BuildAndRunPlan(step) + self.CheckNetAllOutput(bool_net, [True, False]) + + def testCombineConditions(self): + # combined by 'Or' + combine_net = control.CombineConditions( + 'test', [self.true_cond_net_, self.false_cond_net_], 'Or') + step = control.Do('combine', + self.true_cond_net_, + self.false_cond_net_, + combine_net) + self.BuildAndRunPlan(step) + self.CheckNetOutput([(combine_net, True)]) + + # combined by 'And' + combine_net = control.CombineConditions( + 'test', [self.true_cond_net_, self.false_cond_net_], 'And') + step = control.Do('combine', + self.true_cond_net_, + self.false_cond_net_, + combine_net) + self.BuildAndRunPlan(step) + self.CheckNetOutput([(combine_net, False)]) + + def testMergeConditionNets(self): + # merged by 'Or' + merge_net = control.MergeConditionNets( + 'test', [self.true_cond_net_, self.false_cond_net_], 'Or') + step = control.Do('merge', merge_net) + self.BuildAndRunPlan(step) + self.CheckNetOutput([(merge_net, True)]) + + # merged by 'And' + merge_net = control.MergeConditionNets( + 'test', [self.true_cond_net_, self.false_cond_net_], 'And') + step = control.Do('merge', merge_net) + self.BuildAndRunPlan(step) + self.CheckNetOutput([(merge_net, False)]) diff --git a/caffe2/python/convnet_benchmarks.py b/caffe2/python/convnet_benchmarks.py index 7414e4715254..27abbc7f3eb0 100644 --- a/caffe2/python/convnet_benchmarks.py +++ b/caffe2/python/convnet_benchmarks.py @@ -630,6 +630,7 @@ def GetArgumentParser(): parser.add_argument("--net_type", type=str, default="dag") parser.add_argument("--num_workers", type=int, default=2) parser.add_argument("--use-nvtx", default=False, action='store_true') + parser.add_argument("--htrace_conf", type=str) return parser @@ -643,7 +644,9 @@ if __name__ == '__main__': workspace.GlobalInit( ['caffe2', '--caffe2_log_level=0'] + - (['--caffe2_use_nvtx'] if args.use_nvtx else [])) + (['--caffe2_use_nvtx'] if args.use_nvtx else []) + + (['--caffe2_htrace_conf=' + args.htrace_conf] + if args.htrace_conf else [])) model_map = { 'AlexNet': AlexNet, 'OverFeat': OverFeat, diff --git a/caffe2/python/core.py b/caffe2/python/core.py index 81147c5eb866..2f5eb7682672 100644 --- a/caffe2/python/core.py +++ b/caffe2/python/core.py @@ -8,7 +8,8 @@ from collections import OrderedDict from caffe2.proto import caffe2_pb2 from collections import defaultdict -from caffe2.python import scope, utils, workspace, extension_loader +from caffe2.python import scope, utils, workspace +import numpy as np import caffe2.python._import_c_extension as C @@ -122,6 +123,9 @@ class BlobReference(object): def Net(self): return self._from_net + def GetNameScope(self): + return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1] + def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs): """Internal function that routes the operator generation to the network's __getattr__ function. @@ -156,9 +160,14 @@ class BlobReference(object): op_type, *args, **kwargs) +def ScopedName(name): + """prefix the name with the current scope.""" + return scope.CurrentNameScope() + name + + def ScopedBlobReference(name, *args, **kwargs): """Returns a blob reference with scope prefixed.""" - return BlobReference(scope.NAMESCOPE + name, *args, **kwargs) + return BlobReference(ScopedName(name), *args, **kwargs) def _RectifyInputOutput(blobs, net=None): @@ -166,8 +175,8 @@ def _RectifyInputOutput(blobs, net=None): interface. """ if isinstance(blobs, basestring): - # If blobs is a single string, prepend scope.NAMESCOPE and put it as a - # list. + # If blobs is a single string, prepend scope.CurrentNameScope() + # and put it as a list. # TODO(jiayq): enforce using BlobReference instead of raw strings. return [ScopedBlobReference(blobs, net=net)] elif type(blobs) is BlobReference: @@ -221,12 +230,13 @@ def CreateOperator( operator.control_input.extend([str(i) for i in control_input]) # Set device option: # (1) If device_option is explicitly set, use device_option. - # (2) If not, but scope.DEVICESCOPE is set, then we use scope.DEVICESCOPE. + # (2) If not, but scope.CurrentDeviceScope() is set, + # then we use scope.CurrentDeviceScope(). # (3) Otherwise, do not set device option. if device_option is not None: operator.device_option.CopyFrom(device_option) - elif scope.DEVICESCOPE is not None: - operator.device_option.CopyFrom(scope.DEVICESCOPE) + elif scope.CurrentDeviceScope() is not None: + operator.device_option.CopyFrom(scope.CurrentDeviceScope()) if engine is not None: operator.engine = engine # random seed is defined in the device option, so we need to do special @@ -246,6 +256,14 @@ def CreateOperator( return operator +def CreatePythonOperator(f, inputs, outputs, grad_f=None, *args, **kwargs): + token = C.register_python_op(f) + if grad_f: + C.register_python_gradient_op(token, grad_f) + kwargs["token"] = token + return CreateOperator("Python", inputs, outputs, *args, **kwargs) + + def GetIndexFromGradientList(g_list, name): """A helper function to get the index from a gradient list, None if not matching.""" @@ -665,13 +683,17 @@ class GradientRegistry(object): def GetGradientForOp(cls, op, g_output): try: gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output) - except Exception: + except Exception as e: # Not supported in C++; will try python registration next. + try: gradient_ops, g_input = cls.gradient_registry_[op.type]( op, g_output) except KeyError: - raise KeyError('No gradient registered for op: %s' % op.type) + raise Exception( + "No gradient registered for {}. ".format(op.type) + + "Exception from creating the gradient op: {}.".format(e)) + if gradient_ops is None: return [], g_input if type(gradient_ops) is not list: @@ -785,6 +807,59 @@ def get_op_ids_in_path(ssa, blob_versions, inputs, outputs): return sorted(used_op_ids) +def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None): + """ + Clone the given Net, binding its input schema to the given `inputs` record. + Blob names defined by the net are prepended with the given `prefix`. + + Args: + net: the net to clone + name: the name of the new net + prefix: the prefix to append to local blobs + blob_remap: (optional) dict with additional blob name remapping. + inputs: (optional) input record that will provide actual input + values for the cloned net. Must be compatible with the + net's input schema. + Returns: + Tuple (cloned_net, blob_remap) + clone_net: the cloned Net + blob_remap: a map from original blob names into remapped blob names + """ + from caffe2.python import schema + assert isinstance(net, Net) + if blob_remap is None: + blob_remap = {} + if inputs is not None: + assert isinstance(inputs, schema.Field) + original = net.input_record() + assert original is not None + # TODO(azzolini): improve schema type checking + assert set(original.field_names()) == set(inputs.field_names()), ( + 'Schemas do not match.') + original_mapping = dict(zip(original.field_names(), + original.field_blobs())) + for a, b in zip(inputs.field_names(), inputs.field_blobs()): + blob_remap[str(original_mapping[a])] = str(b) + proto = net.Proto() + ssa, blob_versions = get_ssa(proto) + undef_blobs = get_undefined_blobs(ssa) + + for blob in blob_versions.keys(): + if blob in blob_remap: + continue + elif blob in undef_blobs: + blob_remap[blob] = blob + else: + blob_remap[blob] = prefix + blob + return net.Clone(name, blob_remap), blob_remap + + +def _get_blob_ref(blob_name_or_ref): + return ( + blob_name_or_ref if isinstance(input, BlobReference) + else BlobReference(blob_name_or_ref) + ) + class Net(object): _net_names_used = set() operator_registry_ = {} @@ -806,6 +881,9 @@ class Net(object): name_or_proto: If a NetDef is provided, clone it. Otherwise, create an empty net with the given name. """ + self._input_record = None + self._output_record = None + self._attr_dict = defaultdict(list) if type(name_or_proto) is caffe2_pb2.NetDef: proto = name_or_proto # We rae initializing a network by a NetDef. In this case, we will @@ -840,9 +918,76 @@ class Net(object): # make sure that this net name hasn't been used before self._net.name = Net._get_next_net_name(self._net.name) - def __str__(self): + def AppendNet(self, net): + assert isinstance(net, Net) + self.Proto().op.extend(net.Proto().op) + self.Proto().external_input.extend( + [i for i in net.Proto().external_input + if i not in self.Proto().external_input]) + self.Proto().external_output.extend( + [o for o in net.Proto().external_output + if o not in self.Proto().external_output]) + return self + + def LogInfo(self, *msg_or_blobs): + for msg_or_blob in msg_or_blobs: + if not isinstance(msg_or_blob, BlobReference): + blob = self.GivenTensorStringFill( + [], self.NextName('log'), + shape=[], values=[msg_or_blob]) + else: + blob = msg_or_blob + self.Print(blob, []) + + def add_attribute(self, name, obj): + """ + Add `obj` to the list of attributes in this net under the given `name`. + Attributes are user-defined objects and have no pre-defined semantics. + """ + self._attr_dict[name].append(obj) + + def get_attributes(self, name): + """ + Returns the list of attributes in this net for a given `name`. + Attributes are user-defined objects added with `add_attribute'. + """ + return self._attr_dict.get(name, []) + + def Name(self): return self._net.name + def __str__(self): + return self.Name() + + def Const(self, array, blob_out=None, dtype=None): + if isinstance(array, bool): + return self.ConstantFill( + [], + blob_out or 1, + dtype=DataType.BOOL, + value=array) + + if dtype is None: + array = np.array(array) + else: + array = np.array(array, dtype=dtype) + + def do_set(operator): + return operator( + [], + blob_out or 1, + shape=array.shape, + values=array.flatten().tolist()) + + if array.dtype == np.int32: + return do_set(self.GivenTensorIntFill) + elif array.dtype == np.int64: + return do_set(self.GivenTensorInt64Fill) + elif array.dtype == np.str: + return do_set(self.GivenTensorStringFill) + else: + return do_set(self.GivenTensorFill) + def BlobIsDefined(self, blob): """ Returns true if the given BlobReference is produced as output of @@ -925,7 +1070,27 @@ class Net(object): new_proto.op.extend(remap_op(proto.op[op_id]) for op_id in op_id_mask) remap_list(new_proto.external_input) remap_list(new_proto.external_output) - return Net(new_proto) + new_net = Net(new_proto) + + from caffe2.python import schema + if self._input_record: + new_net._input_record = schema.from_blob_list( + self._input_record, + [ + BlobReference(str(blob_remap[str(blob)]), net=new_net) + for blob in self._input_record.field_blobs() + ], + ) + if self._output_record: + new_net._output_record = schema.from_blob_list( + self._output_record, + [ + BlobReference(str(blob_remap[str(blob)]), net=new_net) + for blob in self._output_record.field_blobs() + ], + ) + new_net._attr_dict.update(self._attr_dict) + return new_net def ClonePartial(self, name, inputs, outputs, remap_funcs=None): """ @@ -1051,14 +1216,49 @@ class Net(object): assert input_name not in self._net.external_input, ( 'Net already contains an input named %s' % input_name) self._net.external_input.extend([input_name]) - return ( - input if isinstance(input, BlobReference) - else BlobReference(input_name)) + return _get_blob_ref(input_name) def AddExternalOutput(self, output): assert isinstance(output, BlobReference) assert self.BlobIsDefined(output) self.Proto().external_output.extend([str(output)]) + return output + + @property + def external_inputs(self): + return map(_get_blob_ref, self._net.external_input) + + @property + def external_outputs(self): + return map(_get_blob_ref, self._net.external_output) + + def set_input_record(self, input_record): + from caffe2.python import schema + assert self._input_record is None, ( + 'Input schema cannot be reset') + if not input_record.has_blobs(): + self._input_record = schema.NewRecord(self, input_record) + else: + self._input_record = input_record + for blob in input_record.field_blobs(): + if blob not in self.external_inputs: + self.AddExternalInput(blob) + return self._input_record + + def set_output_record(self, record): + assert self._output_record is None, ( + 'Output record cannot be reset') + for blob in record.field_blobs(): + assert self.BlobIsDefined(blob) + for blob in record.field_blobs(): + self.AddExternalOutput(blob) + self._output_record = record + + def input_record(self): + return self._input_record + + def output_record(self): + return self._output_record def DeduplicateGradientSlices(self, g): assert isinstance(g, GradientSlice) @@ -1115,13 +1315,10 @@ class Net(object): op_type, *args, **kwargs) def Python(self, f, grad_f=None): - with extension_loader.DlopenGuard(): - import caffe2.python.op.python_ops_python as ops_python - RefreshRegisteredOperators() assert(IsOperator('Python')) - token = ops_python.register(f) + token = C.register_python_op(f) if grad_f: - ops_python.register_gradient(token, grad_f) + C.register_python_gradient_op(token, grad_f) return lambda *args, **kwargs: self._CreateAndAddToSelf( 'Python', token=token, *args, **kwargs) @@ -1165,9 +1362,21 @@ def _add_net_to_dict(net_dict, net): class ExecutionStep(object): + _step_names_used = set() + + @staticmethod + def _get_next_step_name(basename): + name = basename + next_idx = 1 + while name in ExecutionStep._step_names_used: + name = basename + '_' + str(next_idx) + next_idx += 1 + ExecutionStep._step_names_used |= set([name]) + return name + def __init__(self, name, nets=None, num_iter=None): self._step = caffe2_pb2.ExecutionStep() - self._step.name = name + self._step.name = name or ExecutionStep._get_next_step_name('step') self._net_dict = OrderedDict() self._is_used = False self._substeps = [] @@ -1180,6 +1389,9 @@ class ExecutionStep(object): if num_iter is not None: self._step.num_iter = num_iter + def get_net(self, name): + return self._net_dict[name] + def Name(self): return self._step.name @@ -1191,7 +1403,6 @@ class ExecutionStep(object): 'Cannot mutate a step that has already been added to a plan/step.') def _notify_is_used(self): - self._assert_can_mutate() self._is_used = True def Proto(self): @@ -1215,6 +1426,10 @@ class ExecutionStep(object): self._assert_can_mutate() self._step.num_iter = num_iter + def SetOnlyOnce(self, only_once): + self._assert_can_mutate() + self._step.only_once = only_once + def SetShouldStopBlob(self, should_stop_blob): assert isinstance(should_stop_blob, BlobReference), ( "expects BlobReference here, got {}".format(type(should_stop_blob))) @@ -1256,6 +1471,30 @@ class ExecutionStep(object): self._step.network.extend([get_net_name(net)]) return self + def get_all_attributes(self, name): + """ + Return the list of all attributes under the given `name`, present in + all of the nets used in this execution step and its children. + """ + objs = [] + for net in self._net_dict.values(): + objs += net.get_attributes(name) + return objs + + +def add_nets_in_order(step, net_list): + proto = step.Proto() + for substep in step.Substeps(): + add_nets_in_order(substep, net_list) + for net in proto.network: + if net not in net_list: + net_list.append(net) + # FIXME(azzolini): This is actually wrong. Report nets should be + # instantiated first since they may run before any substep is run. + # However, curerntly, Reporter depends on this behavior. + if proto.report_net and proto.report_net not in net_list: + net_list.append(proto.report_net) + class Plan(object): def __init__(self, name_or_step): @@ -1290,7 +1529,33 @@ class Plan(object): if not step.HasNets() and not step.HasSubsteps(): return self._plan.execution_step.add().CopyFrom(step.Proto()) - self.AddNets(step.Nets()) + # nets need to be added to the plan in order of usage + net_list = [] + add_nets_in_order(step, net_list) + self.AddNets([step.get_net(n) for n in net_list]) + + def get_all_attributes(self, name): + """ + Return the list of all attributes under the given `name`, present in + all of the nets used in this plan. + """ + objs = [] + for net in self._net_dict.values(): + objs += net.get_attributes(name) + return objs + + +def to_execution_step(step_or_nets, default_name=None): + from caffe2.python.net_builder import NetBuilder + if isinstance(step_or_nets, ExecutionStep): + return step_or_nets + + stop_blob = None + if isinstance(step_or_nets, NetBuilder): + stop_blob = step_or_nets._stop_blob + step_or_nets = step_or_nets.get() + return execution_step( + default_name, step_or_nets, should_stop_blob=stop_blob) def execution_step(default_name, @@ -1299,7 +1564,8 @@ def execution_step(default_name, report_net=None, report_interval=None, concurrent_substeps=None, - should_stop_blob=None): + should_stop_blob=None, + only_once=None): """ Helper for creating an ExecutionStep. - steps_or_nets can be: @@ -1319,38 +1585,29 @@ def execution_step(default_name, if should_stop_blob is None and num_iter is None: num_iter = 1 - def set_step_attr(step): - if should_stop_blob is not None: - step.SetShouldStopBlob(should_stop_blob) - else: - step.SetIter(num_iter) - if concurrent_substeps is not None: - step.SetConcurrentSubsteps(concurrent_substeps) - if report_net is not None: - assert report_interval is not None - step.SetReportNet(report_net, report_interval) - return step + step = ExecutionStep(default_name) + if should_stop_blob is not None: + step.SetShouldStopBlob(should_stop_blob) + if num_iter is not None: + step.SetIter(num_iter) + if only_once is not None: + step.SetOnlyOnce(only_once) + if concurrent_substeps is not None: + step.SetConcurrentSubsteps(concurrent_substeps) + if report_net is not None: + assert report_interval is not None + step.SetReportNet(report_net, report_interval) - if not steps_or_nets: - return ExecutionStep(default_name) if isinstance(steps_or_nets, ExecutionStep): - step = set_step_attr(ExecutionStep(default_name)) step.AddSubstep(steps_or_nets) - return step elif isinstance(steps_or_nets, Net): - step = set_step_attr(ExecutionStep(default_name)) step.AddNet(steps_or_nets) - return step elif isinstance(steps_or_nets, list): - step = set_step_attr(ExecutionStep(default_name)) - for step_or_net in steps_or_nets: - if isinstance(step_or_net, Net): - step.AddNet(step_or_net) - elif isinstance(step_or_net, ExecutionStep): - step.AddSubstep(step_or_net) - else: - raise ValueError('unsupported type {}'.format(step_or_net)) - return step - else: + if all(isinstance(x, Net) for x in steps_or_nets): + map(step.AddNet, steps_or_nets) + else: + map(step.AddSubstep, map(to_execution_step, steps_or_nets)) + elif steps_or_nets: raise ValueError( 'steps_or_nets must be a step, a net, or a list of nets or steps.') + return step diff --git a/caffe2/python/data_parallel_model.py b/caffe2/python/data_parallel_model.py index 17ab2df82f61..58d1187972c6 100644 --- a/caffe2/python/data_parallel_model.py +++ b/caffe2/python/data_parallel_model.py @@ -2,481 +2,381 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from types import FunctionType -from functools import wraps -import six +from collections import OrderedDict +import logging -from caffe2.python import cnn, dyndep, scope, workspace, core +from caffe2.python import model_helper, dyndep, scope, workspace, core from caffe2.proto import caffe2_pb2 dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/nccl:nccl_ops") - -DATAPARALLEL_OPS = [ - "Conv", - "ConvTranspose", - "GroupConv", - "FC", - "FC_Decomp", - "FC_Prune", - "FC_Sparse", - "LRN", - "Dropout", - "MaxPool", - "AveragePool", - "Concat", - "DepthConcat", - "Relu", - "Transpose", - "SpatialBN", - "Accuracy", - "Adam", - "AveragedLoss", - "Cast", - "LabelCrossEntropy", - "LearningRate", - "Print", - "Scale", - "Snapshot", - "Softmax", - "StopGradient", - "Summarize", - "Sum", - "Tanh", - "WeightedSum", - "SquaredL2Distance", -] +log = logging.getLogger("data_parallel_model") +log.setLevel(logging.INFO) -class _GPUDataParallelMetaClass(type): - """A meta class to patch method in order to distribute them over multiple - GPUs. - """ - _devices = [] +def Parallelize_GPU( + model_helper_obj, + input_builder_fun, + forward_pass_builder_fun, + param_update_builder_fun, + devices=range(0, workspace.NumCudaDevices()), + mpi_comm=None, + all_reduce_engine=None, +): + ''' + Function to create a model that can run on many GPUs. + model_helper_obj: an object of ModelHelperBase, such as CNNModelHelper + input_builder_fun: + Function that adds the input operators + Note: Remember to instantiate reader outside of this + function so all GPUs share same reader object. + Signature: input_builder_fun(model) + forward_pass_builder_fun: + Function to add the operators to the model. + Must return list of loss-blob references that + are used to build the gradient. + Signature: forward_pass_builder_fun(model) + param_update_builder_fun: + Function that adds operators that are run after + gradient update, such as updating the weights and + weight decaying. + Signature: param_update_builder_fun(model) + devices: List of GPU ids, such as [0, 1, 2, 3] + mpi_comm: MPI communicator object if distribuetd computation + is being used. Use SetupMPICluster() function to + create. Default is None. + all_reduce_engine For MPI reduce: RDMA_IBVERBS, RDMA_TCP, or MPI - @staticmethod - def _data_parallel_wrapper(op): - @wraps(op) - def wrapped(cls, blob_in, blob_out, *args, **kwargs): - # Helpers to extract a device specific blob or a global blob - def self_or_item(d, key): - if isinstance(d, dict): - assert key in d - return d[key] - return d + ''' + log.info("Parallelizing model for devices: {}".format(devices)) + mpi_workers = 8 if mpi_comm is None else 0 # best-guess + model_helper_obj.net.Proto().num_workers = len(devices) * 2 + mpi_workers + model_helper_obj.net.Proto().type = 'dag' - def get_input(gpu_id): - if isinstance(blob_in, list): - return [self_or_item(blob, gpu_id) for blob in blob_in] - return self_or_item(blob_in, gpu_id) + # Store some information in the model -- a bit ugly + model_helper_obj._devices = devices + model_helper_obj._mpi_comm = mpi_comm + model_helper_obj._grad_names = [] - def get_output(gpu_id): - return self_or_item(blob_out, gpu_id) + assert isinstance(model_helper_obj, model_helper.ModelHelperBase) + assert model_helper_obj.params == [], "Model needs to be empty" - # If we have explicit device scope, we do not parallelize - if cls.explicit_scope(): - return op( - cls, - blob_in, - blob_out, - *args, - **kwargs) + if mpi_comm is not None: + assert all_reduce_engine in ['MPI', 'RDMA_IBVERBS', 'RDMA_TCP'] - devices = _GPUDataParallelMetaClass._devices - results = {} - for gpu_id in devices: - with core.NameScope("gpu_{}".format(gpu_id)): - device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id) - with core.DeviceScope(device): - result = op( - cls, - get_input(gpu_id), - get_output(gpu_id), - *args, - **kwargs) - results[gpu_id] = result - return results + # Add input and model + log.info("Create input and model training operators") - return wrapped + losses_by_gpu = {} + for device in devices: + device_opt = core.DeviceOption(caffe2_pb2.CUDA, device) + with core.DeviceScope(device_opt): + with core.NameScope("gpu_{}".format(device)): + log.info("Model for GPU: {}".format(device)) + input_builder_fun(model_helper_obj) + losses = forward_pass_builder_fun(model_helper_obj) + assert isinstance(losses, list), \ + 'Model builder function must return a list of loss blobs' + for loss in losses: + assert isinstance(loss, core.BlobReference), \ + 'Model builder func must return a list of loss blobs' - def __new__(meta, classname, bases, class_dict): - assert len(bases) == 1, "Expects only one base class" - base = bases[0] - assert base is cnn.CNNModelHelper, "Base class should be CNNModelHelper" - new_class_dict = {} - for name, attr in base.__dict__.items(): - if name not in DATAPARALLEL_OPS: - continue - attr = _GPUDataParallelMetaClass._data_parallel_wrapper(attr) - new_class_dict[name] = attr - for name, attr in class_dict.items(): - if name in new_class_dict: - continue - if isinstance(attr, FunctionType): - if name in DATAPARALLEL_OPS: - new_class_dict[name] = \ - _GPUDataParallelMetaClass._data_parallel_wrapper(attr) - else: - new_class_dict[name] = attr - return super(_GPUDataParallelMetaClass, meta).__new__( - meta, classname, bases, new_class_dict) + losses_by_gpu[device] = losses + + # Create parameter map + model_helper_obj._device_grouped_blobs =\ + _GroupByDevice(devices, model_helper_obj.params) + model_helper_obj._param_names =\ + model_helper_obj._device_grouped_blobs.keys() + + if (param_update_builder_fun is None): + log.info("Parameter update function not defined --> only forward") + return + + log.info("Adding gradient operators") + _AddGradientOperators(devices, model_helper_obj, losses_by_gpu) + + # Group gradients by device and register to blob lookup + param_to_grad = model_helper_obj.param_to_grad + grads_ordered = [param_to_grad[p] for p in + model_helper_obj.params if p in param_to_grad] + gradients_grouped = _GroupByDevice( + devices, + grads_ordered, + ) + model_helper_obj._device_grouped_blobs.update(gradients_grouped) + model_helper_obj._grad_names = gradients_grouped.keys() + + log.info("Add gradient all-reduces for SyncSGD") + _AllReduceGradients(devices, model_helper_obj, all_reduce_engine, mpi_comm) + + log.info("Post-iteration operators for updating params") + for device in devices: + device_opt = core.DeviceOption(caffe2_pb2.CUDA, device) + with core.DeviceScope(device_opt): + with core.NameScope("gpu_{}".format(device)): + param_update_builder_fun(model_helper_obj) + + # Add initial parameter syncs + log.info("Add initial parameter sync") + if (mpi_comm is not None): + _AddMPIParameterSync( + devices, + model_helper_obj, + model_helper_obj.param_init_net, + mpi_comm, + ) + + _SyncParams(devices, model_helper_obj, model_helper_obj.param_init_net) -@six.add_metaclass(_GPUDataParallelMetaClass) -class GPUDataParallelModel(cnn.CNNModelHelper): - """A helper class that extends CNNModelHelper to support multi GPUs - data parallel training. - """ - def __init__(self, devices, *args, **kwargs): - assert len(devices) >= 1, "Should have at least 1 GPU devices" - assert len(devices) <= workspace.NumCudaDevices(), \ - "Requested # of devices {} is greater than the # of GPUs {}".\ - format(devices, workspace.NumCudaDevices()) - _GPUDataParallelMetaClass._devices = devices - self._devices = devices - self._explicit_scope = False - self._gradient_reduce_all_added = False - self._mpi_comm = None - super(GPUDataParallelModel, self).__init__(*args, **kwargs) +def _AddGradientOperators(devices, model, losses_by_gpu): + def create_grad(lossp): + return model.ConstantFill(lossp, str(lossp) + "_grad", value=1.0) - def explicit_scope(self): - return self._explicit_scope + loss_grad = {} + # Explicitly need to create gradients on each GPU + for gpu_id in devices: + device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id) + with core.DeviceScope(device): + for l in losses_by_gpu[gpu_id]: + lg = create_grad(l) + loss_grad[str(l)] = str(lg) - def _call(self, name, *args, **kwargs): - return super(GPUDataParallelModel, self).__getattr__( - name)(*args, **kwargs) + model.AddGradientOperators(loss_grad) - # TODO(denisy): try out decorators to avoid this code below - def Accuracy(self, *args, **kwargs): - return self._call("Accuracy", *args, **kwargs) - def Adam(self, *args, **kwargs): - return self._call("Adam", *args, **kwargs) +def FinalizeAfterCheckpoint(model, blobs, sync_iter=True): + if not hasattr(model, "_checkpoint_net"): + uniq_blob_names = [stripParamName(p) for p in blobs] - def AveragedLoss(self, *args, **kwargs): - return self._call("AveragedLoss", *args, **kwargs) + # Synchronize to the blob lookup map, as the provided + # blobs might have non-parameters, such as momemtum blobs. + log.info("Creating checkpoint synchronization net") + devices = model.GetDevices() + for name in uniq_blob_names: + if name not in model._device_grouped_blobs: + grouped = { + d: + core.BlobReference("gpu_{}{}{}".format( + d, + scope._NAMESCOPE_SEPARATOR, + name) + ) for d in devices} + model._device_grouped_blobs[name] = grouped - def Cast(self, *args, **kwargs): - return self._call("Cast", *args, **kwargs) + model._checkpoint_net = core.Net("checkpoint_sync_net") + model._checkpoint_net.RunAllOnGPU() - def LabelCrossEntropy(self, *args, **kwargs): - return self._call("LabelCrossEntropy", *args, **kwargs) - - def LearningRate(self, *args, **kwargs): - return self._call("LearningRate", *args, **kwargs) - - def Print(self, *args, **kwargs): - return self._call("Print", *args, **kwargs) - - def Scale(self, *args, **kwargs): - return self._call("Scale", *args, **kwargs) - - def Snapshot(self, *args, **kwargs): - return self._call("Snapshot", *args, **kwargs) - - def Softmax(self, *args, **kwargs): - return self._call("Softmax", *args, **kwargs) - - def StopGradient(self, *args, **kwargs): - return self._call("StopGradient", *args, **kwargs) - - def Sum(self, *args, **kwargs): - return self._call("Sum", *args, **kwargs) - - def Summarize(self, *args, **kwargs): - return self._call("Summarize", *args, **kwargs) - - def Tanh(self, *args, **kwargs): - return self._call("Tanh", *args, **kwargs) - - def WeightedSum(self, *args, **kwargs): - return self._call("WeightedSum", *args, **kwargs) - - def SquaredL2Distance(self, *args, **kwargs): - return self._call("SquaredL2Distance", *args, **kwargs) - - def SetMPIComm(self, mpi_comm): - self._mpi_comm = mpi_comm - - def FinalizeSetup(self): - self.param_init_net.RunAllOnGPU() - self.RunAllOnGPU() - - # If MPI enabled, broadcast params from master - if (self._mpi_comm is not None): - self._AddMPIParameterSync() + if (model._mpi_comm is not None): + _AddMPIParameterSync( + devices, + model, + model._checkpoint_net, + model._mpi_comm, + uniq_blob_names, + ) # Setup sync of initial params - self._SyncInitialParams() + _SyncParams(devices, model, model._checkpoint_net, uniq_blob_names) - def AddGradientOperators(self, params, *args, **kwargs): - def create_grad(param): - return self.ConstantFill(param, str(param) + "_grad", value=1.0) + # Sync ITER -- which is in CPU scope + if sync_iter: + with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)): + for gpu_idx in devices[1:]: + model._checkpoint_net.Copy( + "gpu_{}/ITER".format(devices[0]), + "gpu_{}/ITER".format(gpu_idx), + ) - param_grad = {} - # Explicitly need to create gradients on each GPU - for param in params: - if not isinstance(param, dict): - grad = create_grad(param) - param_grad[str(param)] = str(grad) - else: - for gpu_id in self._devices: - device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id) - with core.DeviceScope(device): - assert gpu_id in param - p = param[gpu_id] - g = create_grad(p) - param_grad[str(p)] = str(g) + # Run the sync + log.info("Run checkpoint net") + workspace.RunNetOnce(model._checkpoint_net) - return super(GPUDataParallelModel, self).AddGradientOperators( - param_grad, *args, **kwargs) - def AddWeightDecay(self, weight_decay): - if weight_decay == 0.0: - return +def _Broadcast(devices, model, net, param): + # TODO(akyrola): replace with NCCLBroadcast when it's working + # Copy params from gpu_0 to other + master_gpu = devices[0] + for gpu_idx in devices[1:]: + device_opt = core.DeviceOption(caffe2_pb2.CUDA, gpu_idx) + with core.DeviceScope(device_opt): + net.Copy( + model._device_grouped_blobs[param][master_gpu], + model._device_grouped_blobs[param][gpu_idx] + ) - assert(weight_decay > 0.0) - self._explicit_scope = True - assert \ - self._gradient_reduce_all_added, \ - "Weight decay must be done after gradient sync between gpus" +def _SyncParams(devices, model, net, unique_param_names=None): + if unique_param_names is None: + unique_param_names = model._param_names - for gpu_id in self._devices: - with core.NameScope("gpu_{}".format(gpu_id)): - device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id) - with core.DeviceScope(device): - wd = self.param_init_net.ConstantFill([], 'wd', shape=[1], - value=weight_decay) - ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], - value=1.0) - # Only update parameters that belong to the current GPU - params = self._CurrentScopeParams() + for param in unique_param_names: + _Broadcast(devices, model, net, param) - # Take only params that are weights - print("Adding weigth-decay for gpu {}.".format(gpu_id)) - gpu_weights = [p for p in params if p in self.weights] - for w in gpu_weights: - # Equivalent to grad -= w * param - grad = self.param_to_grad[w] - self.net.WeightedSum([grad, ONE, w, wd], grad) +def _AddMPIParameterSync(devices, model, net, mpi_comm, uniq_param_names=None): + if uniq_param_names is None: + uniq_param_names = model._param_names - self._explicit_scope = False + device_opt = core.DeviceOption(caffe2_pb2.CUDA, devices[0]) - def _Broadcast(self, net, param): - # TODO(akyrola): replace with NCCLBroadcast when it's working - # Copy params from gpu_0 to other - for gpu_idx in self._devices[1:]: - device_opt = core.DeviceOption(caffe2_pb2.CUDA, gpu_idx) - with core.DeviceScope(device_opt): - net.Copy( - "gpu_{}/{}".format(self._devices[0], param), - "gpu_{}/{}".format(gpu_idx, param) - ) - - def _SyncInitialParams(self): - unique_param_names = set( - stripParamName(p) - for p in self.params + # ITER is in CPU scope :( + with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)): + net.Broadcast( + inputs=[mpi_comm, "gpu_0/ITER"], + outputs=["gpu_0/ITER"], + engine='MPI' ) - self._explicit_scope = True - for param in unique_param_names: - self._Broadcast(self.param_init_net, param) - - self._explicit_scope = False - - def _AddMPIParameterSync(self): - # Sync from master - unique_param_names = set( - stripParamName(p) - for p in self.params - ) - - self._explicit_scope = True - - # Should this be done in GPU 0 scope? - for param_name in unique_param_names: - param = "gpu_{}/{}".format(self._devices[0], param_name) - self.param_init_net.Broadcast( - inputs=[self._mpi_comm, param], + with core.DeviceScope(device_opt): + for param_name in sorted(uniq_param_names): + param = model._device_grouped_blobs[param_name][devices[0]] + net.Broadcast( + inputs=[mpi_comm, param], outputs=[param], engine='MPI' ) - self._explicit_scope = False - def _AllReduceGradients(self): - self._gradient_reduce_all_added = True - if self._mpi_comm is None: - self._AllReduceGradientsSingleHost() - else: - self._AllReduceGradientsWithMPI() +def _AllReduceGradients(devices, model, all_reduce_engine, mpi_comm): + if mpi_comm is None: + _AllReduceGradientsSingleHost(devices, model) + else: + _AllReduceGradientsWithMPI(devices, model, all_reduce_engine, mpi_comm) - def _AllReduceGradientsWithMPI(self): - self._explicit_scope = True - unique_grads_names = set( - stripParamName(grad) - for grad in self.param_to_grad.values() - ) - # Step 1: sum gradients from local GPUs to master GPU - last_out = None - master_device_opt = core.DeviceOption(caffe2_pb2.CUDA, self._devices[0]) +def _AllReduceGradientsWithMPI(devices, model, all_reduce_engine, mpi_comm): + num_workers = model.net.Proto().num_workers + assert num_workers > 1, "Please specify more than 1 worker" - # Note: sorted order to ensure each host puts the operators in - # same order. - for grad_name in sorted(unique_grads_names): - grads_group = [ - grad - for grad in self.param_to_grad.values() - if stripParamName(grad) == grad_name - ] - master_grad = "gpu_{}/{}".format(self._devices[0], grad_name) - assert master_grad in grads_group + # Make list of gradients in reverse order + reverse_ordered_grads = _GetReverseOrderedGrads(model) - # Remark: NCCLReduce does not support in-place modifications - # so we need a temporary gradient blob - reduced_grad = "gpu_{}/{}_red".format( - self._devices[0], - grad_name + # Step 1: sum gradients from local GPUs to master GPU + master_device_opt = core.DeviceOption(caffe2_pb2.CUDA, devices[0]) + reducing_device_opt = master_device_opt + if all_reduce_engine == "RDMA_TCP": + reducing_device_opt = core.DeviceOption(caffe2_pb2.CPU, 0) + + # We need to specify a partial order using control_input to + # ensure progress (since all machines need to do same all reduces + # in parallel) + num_controls = min(4, num_workers - 1) + if all_reduce_engine in ['MPI']: + # With MPI we need to sequentialize + num_controls = 1 + assert num_controls > 0 + + cyclical_controls = [] + counter = 0 + nccl_control_blob = None + + # Note: sorted order to ensure each host puts the operators in + # same order. + for grad_name in reverse_ordered_grads: + master_grad = model._device_grouped_blobs[grad_name][devices[0]] + grads_group = model._device_grouped_blobs[grad_name].values() + + assert master_grad in grads_group + + # Remark: NCCLReduce does not support in-place modifications + # so we need a temporary gradient blob + reduced_grad = str(master_grad) + "_red" + + with core.DeviceScope(master_device_opt): + model.ConstantFill(master_grad, reduced_grad, value=0.0) + + # Temp fix since NCCLReduce does not work + model.net.NCCLAllreduce( + grads_group, + grads_group, + control_input=nccl_control_blob, + ) + nccl_control_blob = grads_group[0] + model.net.Copy(master_grad, reduced_grad) + + # RDMA_TCP works only on CPU context, so we need a temporary + # cpu-bound scratch blob. + if all_reduce_engine == "RDMA_TCP": + with core.DeviceScope(reducing_device_opt): + model.param_init_net.ConstantFill( + [], reduced_grad + "cpu", shape=[1], value=0.0 + ) + with core.DeviceScope(master_device_opt): + # Hack to ensure the cpu-scratch blob is initialized + # prior to running the net. + model.param_init_net.CopyGPUToCPU( + str(master_grad).replace("_grad", ""), reduced_grad + "cpu" + ) + model.net.CopyGPUToCPU(reduced_grad, reduced_grad + "cpu") + reduced_grad = reduced_grad + "cpu" + + control_input = None if len(cyclical_controls) < num_controls \ + else cyclical_controls[counter % num_controls] + + with core.DeviceScope(reducing_device_opt): + # Step 2: allreduce over MPI to all hosts, between master GPUs + model.net.Allreduce( + inputs=[mpi_comm, reduced_grad], + outputs=[reduced_grad], + engine=all_reduce_engine, + control_input=control_input, ) + if reducing_device_opt != master_device_opt: with core.DeviceScope(master_device_opt): - self.ConstantFill(master_grad, reduced_grad, value=0.0) - self.net.NCCLReduce(grads_group, reduced_grad) + model.net.CopyCPUToGPU(reduced_grad, master_grad) + else: + with core.DeviceScope(master_device_opt): + model.net.Copy(reduced_grad, master_grad) - # Step 2: allreduce over MPI to all hosts, between master GPUs - self.net.Allreduce( - inputs=[self._mpi_comm, reduced_grad], - outputs=[master_grad], - engine='MPI', - control_input=None if last_out is None else [last_out], - ) - last_out = master_grad + if len(cyclical_controls) < num_controls: + cyclical_controls.append(reduced_grad) + else: + cyclical_controls[counter % num_controls] = reduced_grad - # Step 3: broadcast locally - self._Broadcast(self.net, grad_name) + counter += 1 - self._explicit_scope = False + # Step 3: broadcast locally + _Broadcast(devices, model, model.net, grad_name) - def _AllReduceGradientsSingleHost(self): - """Performs NCCL AllReduce to distribute gradients to all the GPUs.""" - if len(self._devices) == 1: - return +def _AllReduceGradientsSingleHost(devices, model): + """Performs NCCL AllReduce to distribute gradients to all the GPUs.""" - # Take only params that have gradient associated with them. - unique_grads_names = set( - stripParamName(grad) - for grad in self.param_to_grad.values() - ) + if len(devices) == 1: + return - # Now we need to Allreduce gradients on all the GPUs. - # Pick GPU #0 as a master GPU. - self._explicit_scope = True - master_device_opt = core.DeviceOption(caffe2_pb2.CUDA, self._devices[0]) - with core.DeviceScope(master_device_opt): - # Group by grads for reduce. - for grad_name in unique_grads_names: - grads_group = [ - grad - for grad in self.param_to_grad.values() - if stripParamName(grad) == grad_name - ] - assert len(grads_group) == len(self._devices), \ - "Each GPU from {}, should have a copy of {}.".format( - self._devices, grad_name) - self.NCCLAllreduce(grads_group, grads_group) - self._explicit_scope = False + # Gradients in reverse order + reverse_ordered_grads = _GetReverseOrderedGrads(model) - def _BuildLR(self, base_lr, policy="fixed", **other_lr_params): - """A helper to create learning rate.""" - ITER = self.Iter("ITER") - # There is one interesting thing here: since we are minimizing, we are - # doing "descent" so the learning rate is set to be negative. - LR = self.net.LearningRate( - [ITER], - "LR", - base_lr=base_lr, - policy=policy, - **other_lr_params - ) - return LR + # Now we need to Allreduce gradients on all the GPUs. + # Pick GPU #0 as a master GPU. + master_device_opt = core.DeviceOption(caffe2_pb2.CUDA, devices[0]) + last_out = None + with core.DeviceScope(master_device_opt): + # Group by grads for reduce. + for grad_name in reverse_ordered_grads: + grads_group = model._device_grouped_blobs[grad_name].values() + assert len(grads_group) == len(devices), \ + "Each GPU from {}, should have a copy of {}.".format( + devices, grad_name) + model.NCCLAllreduce( + grads_group, + grads_group, + control_input=last_out, + ) + # last_out is used to serialize the execution of nccls + last_out = grads_group[0] - def _BuildSGD(self, params, base_lr, policy="fixed", **other_lr_params): - """A helper to construct gradient update for SGD.""" - base_lr = base_lr / len(self._devices) - LR = self._BuildLR(base_lr, policy, **other_lr_params) - ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0) - for param in params: - grad = self.param_to_grad[param] - if isinstance(grad, core.GradientSlice): - self.ScatterWeightedSum( - [param, ONE, grad.indices, grad.values, LR], param - ) - else: - self.WeightedSum([param, ONE, grad, LR], param) - def _CurrentScopeParams(self): - return [ - param - for param in self.param_to_grad.keys() - if str(param).startswith(scope.NAMESCOPE) - ] - - def SGD(self, base_lr, policy="fixed", **other_lr_params): - """Adds SGD optimizer to the model.""" - self._AllReduceGradients() - - # Create update params operators. - self._explicit_scope = True - for gpu_id in self._devices: - with core.NameScope("gpu_{}".format(gpu_id)): - device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id) - with core.DeviceScope(device): - # Only update parameters that belong to the current GPU - params = self._CurrentScopeParams() - - # Add optimizer update operators - self._BuildSGD(params, base_lr, policy, **other_lr_params) - self._explicit_scope = False - - def CustomSGD( - self, - paramup_build_fn, - base_lr, - lr_policy, - weight_decay, - **other_lr_pars - ): - """Custom parameter update function""" - self._AllReduceGradients() - - self.AddWeightDecay(weight_decay) - - # Run parameter update on each machine - self._explicit_scope = True - for gpu_id in self._devices: - with core.NameScope("gpu_{}".format(gpu_id)): - device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id) - with core.DeviceScope(device): - LR = self._BuildLR(base_lr, lr_policy, **other_lr_pars) - - params = self._CurrentScopeParams() - paramup_build_fn(self, params, LR) - self._explicit_scope = False - - def ExecOnEachDevice(self, fn, *args, **kwargs): - self._explicit_scope = True - for gpu_id in self._devices: - with core.NameScope("gpu_{}".format(gpu_id)): - device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id) - with core.DeviceScope(device): - fn(self, *args, **kwargs) - - self._explicit_scope = False +def _GetReverseOrderedGrads(model): + ''' + Returns the gradients in reverse order (namespace stripped), + for the optimal synchronization order. + ''' + return list(reversed(model._grad_names)) # A helper function to extract a parameter's name @@ -487,25 +387,60 @@ def stripParamName(param): return name[name.rindex(sep) + 1:] +def _GroupByDevice(devices, params): + ''' + Groups blobs by device, returning a map of [blobname] = {0: BlobRef, 1: ..}. + Returns ordered dictionary, ensuring the original order. + ''' + grouped = OrderedDict() + assert len(params) % len(devices) == 0,\ + "There should be equal number of params per device" + + num_params_per_device = int(len(params) / len(devices)) + + for i, p in enumerate(params): + assert isinstance(p, core.BlobReference), \ + "Param {} is not of type BlobReference".format(p) + + name = stripParamName(p) + gpuid = i // num_params_per_device + assert "gpu_{}/".format(gpuid) in p.GetNameScope(),\ + "Param {} expected to have namescope 'gpu_{}'".format(str(p), gpuid) + + if name not in grouped: + grouped[name] = {} + grouped[name][gpuid] = p + + # Confirm consistency + for j, (p, ps) in enumerate(grouped.items()): + assert \ + len(ps) == len(devices), \ + "Param {} does not have value for each device (only {}: {})".format( + p, len(ps), ps, + ) + # Ensure ordering + assert(ps[devices[0]] == params[j]) + + return grouped + + def SetupMPICluster(num_replicas, role, job_path): from caffe2.python import mpi - print("Initing library") dyndep.InitOpsLibrary('@/caffe2/caffe2/mpi:mpi_ops') - print("Setup peers") + dyndep.InitOpsLibrary('@/caffe2/caffe2/fb/rdma:rdma_ops') + + log.info("MPI: Setup peers") mpi.SetupPeers( replicas=int(num_replicas), role=role, job_path=job_path ) - print("Create mpi_init net") mpi_init_net = core.Net('mpi_init') - print("Create commonworld") mpi_comm = mpi_init_net.CreateCommonWorld( inputs=[], outputs=['comm_world'], - engine='MPI' + engine='MPI', ) - print("Run mpi_init net") workspace.RunNetOnce(mpi_init_net) - print("Finished MPI setup") + log.info("Finished MPI setup") return mpi_comm diff --git a/caffe2/python/data_parallel_model_test.py b/caffe2/python/data_parallel_model_test.py index 653838d0587f..36f9565aed1f 100644 --- a/caffe2/python/data_parallel_model_test.py +++ b/caffe2/python/data_parallel_model_test.py @@ -5,7 +5,7 @@ from __future__ import print_function import numpy as np import unittest from caffe2.proto import caffe2_pb2 -from caffe2.python import core, workspace, data_parallel_model +from caffe2.python import core, workspace, data_parallel_model, cnn from caffe2.python.test_util import TestCase @@ -21,17 +21,42 @@ class GPUDataParallelModelTest(TestCase): ).astype(np.float32) label = np.dot(data, perfect_model)[:, np.newaxis] - model = data_parallel_model.GPUDataParallelModel( - gpu_devices, order="NHWC", name="fake") + def input_builder_fun(model): + return None - fc = model.FC("data", "fc", perfect_model.size, 1, - ("ConstantFill", {}), ("ConstantFill", {}), axis=0) - sq = model.SquaredL2Distance([fc, "label"], "sq") - loss = model.AveragedLoss(sq, "loss") - model.AddGradientOperators([loss]) - model.SGD(-0.1) - model.RunAllOnGPU() + def model_build_fun(model): + fc = model.FC("data", "fc", perfect_model.size, 1, + ("ConstantFill", {}), ("ConstantFill", {}), axis=0) + sq = model.SquaredL2Distance([fc, "label"], "sq") + loss = model.AveragedLoss(sq, "loss") + return [loss] + def param_update_fun(model): + ITER = model.Iter("ITER") + LR = model.net.LearningRate( + [ITER], + "LR", + base_lr=(-0.1 / len(gpu_devices)), + policy="fixed", + ) + ONE = model.param_init_net.ConstantFill( + [], "ONE", shape=[1], value=1.0, + ) + for param in model.GetParams(): + grad = model.param_to_grad[param] + model.WeightedSum([param, ONE, grad, LR], param) + + # Create model + model = cnn.CNNModelHelper(order="NHWC", name="fake") + data_parallel_model.Parallelize_GPU( + model, + input_builder_fun=input_builder_fun, + forward_pass_builder_fun=model_build_fun, + param_update_builder_fun=param_update_fun, + devices=gpu_devices, + ) + + # Feed some data for gpu_id in gpu_devices: with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, gpu_id)): workspace.FeedBlob( @@ -39,6 +64,7 @@ class GPUDataParallelModelTest(TestCase): workspace.FeedBlob( "gpu_{}/label".format(gpu_id), label[0]) + workspace.RunNetOnce(model.param_init_net) workspace.CreateNet(model.net) diff --git a/caffe2/python/dataio.py b/caffe2/python/dataio.py index 6878afc66802..75fa24c22d5f 100644 --- a/caffe2/python/dataio.py +++ b/caffe2/python/dataio.py @@ -20,7 +20,8 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals from caffe2.python import core -from caffe2.python.schema import Field, from_blob_list +from caffe2.python.schema import Field, Struct, from_blob_list +import numpy as np class Reader(object): @@ -36,6 +37,9 @@ class Reader(object): assert self._schema is not None, 'Schema not provided for this reader.' return self._schema + def _set_schema(self, schema): + self._schema = schema + def setup_ex(self, init_net, finish_net): """Nets to be executed once at startup and finish. Experimental extension. Don't use yet""" @@ -152,6 +156,11 @@ class Writer(object): that no more data will be written. """ + _schema = None + + def schema(self): + return self._schema + def write(self, writer_net, fields): """Add operations to `writer_net` that write the next batch of data. @@ -166,6 +175,7 @@ class Writer(object): def write_record(self, writer_net, fields): if isinstance(fields, Field): + self._schema = fields fields = fields.field_blobs() self.write(writer_net, fields) @@ -183,6 +193,7 @@ class Writer(object): self, fields, local_init_net, local_finish_net, stop_blob=None): """Experimental extension to the interface. Don't use yet.""" if isinstance(fields, Field): + self._schema = fields fields = fields.field_blobs() if stop_blob is None: stop_blob = local_init_net.NextName("dequeue_status") @@ -197,3 +208,126 @@ class Writer(object): of them. """ pass + + +class ReaderBuilder(object): + """ Allow usage of a reader in distributed fashion. """ + def schema(self): + raise NotImplementedError() + + def enqueue_splits(self, net, split_queue): + raise NotImplementedError() + + def splits(self, net): + raise NotImplementedError() + + def new_reader(self, split_queue): + raise NotImplementedError() + + +class Pipe(object): + def __init__(self, schema=None, obj_key=None): + self._num_writers = 0 + self._num_readers = 0 + self._schema = schema + self._obj_key = obj_key + + def schema(self): + return self._schema + + def setup(self, global_init_net): + pass + + def reader(self): + raise NotImplementedError() + + def writer(self): + raise NotImplementedError() + + def num_readers(self): + return self._num_readers + + def num_writers(self): + return self._num_writers + + def _new_writer(self, writer_schema, writer_init_net): + if writer_schema is not None and self._schema is None: + self._schema = writer_schema + self._num_writers += 1 + if self._obj_key is not None: + writer_init_net.add_attribute(self._obj_key, self) + + def _new_reader(self, reader_init_net): + self._num_readers += 1 + if self._obj_key is not None: + reader_init_net.add_attribute(self._obj_key, self) + + +class CounterReader(Reader): + """ Reader that produces increasing integers. """ + def __init__(self): + Reader.__init__(self, schema=Struct(('iter', np.int64))) + self.counter = None + self.should_stop = None + + def setup_ex(self, global_init_net, global_finish_net): + if self.counter is None: + self.counter = global_init_net.CreateCounter([], init_count=0) + self.should_stop = global_init_net.ConstantFill( + [], shape=[], dtype=core.DataType.BOOL, value=False) + + def read_ex(self, local_init_net, local_finish_net): + count_net = core.Net('limited_reader_counter') + value = count_net.CountUp([self.counter], 1) + return [count_net], self.should_stop, [value] + + +class ReaderWithLimit(Reader): + """ Reader that stops after `num_iter` calls. """ + def __init__(self, reader, num_iter=1): + Reader.__init__(self, schema=reader._schema) + self.reader = reader + self.counter = None + self.num_iter = num_iter + self._data_finished = None + + def setup_ex(self, global_init_net, global_finish_net): + if self._data_finished is None: + self.counter = global_init_net.CreateCounter( + [], init_count=int(self.num_iter)) + self.reader.setup_ex(global_init_net, global_finish_net) + self._data_finished = global_init_net.ConstantFill( + [], shape=[], value=False, dtype=core.DataType.BOOL) + + def read_ex(self, local_init_net, local_finish_net): + """ 1. check if we reached number of iterations """ + count_net = core.Net('limited_reader_counter') + should_stop = count_net.CountDown([self.counter], 1) + + """ 2. call original reader """ + nets, local_data_finished, fields = self.reader.read_ex( + local_init_net, local_finish_net) + self._set_schema(self.reader._schema) + + """ 3. check if original reader is done. """ + check_done_net = core.Net('limited_reader_post') + check_done_net.Copy(local_data_finished, should_stop) + check_done_net.Copy([local_data_finished], [self._data_finished]) + + # this relies on `should_stop` being called after each net. + return [count_net] + nets + [check_done_net], should_stop, fields + + def data_finished(self): + """ + Return a blob that can be checked after the end of the reading task, + which will contain a scalar float indicating whether the underlying + reader has been exhausted (True) or whether we stopped because reached + the limit of iterations (False). + """ + assert self._data_finished is not None, ( + 'read_record must be called before data_finished()') + return self._data_finished + + +def CountUntil(num_iter): + return ReaderWithLimit(CounterReader(), num_iter) diff --git a/caffe2/python/dataio_test.py b/caffe2/python/dataio_test.py new file mode 100644 index 000000000000..9494944c159f --- /dev/null +++ b/caffe2/python/dataio_test.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python.dataio import ReaderWithLimit +from caffe2.python.dataset import Dataset +from caffe2.python.pipeline import pipe +from caffe2.python.schema import Struct, NewRecord, FeedRecord +from caffe2.python.session import LocalSession +from caffe2.python.task import TaskGroup +from caffe2.python.test_util import TestCase +from caffe2.python import core, workspace +import numpy as np + + +class TestReaderWithLimit(TestCase): + def test_reader_with_limit(self): + ws = workspace.C.Workspace() + session = LocalSession(ws) + + """ 1. feed full dataset """ + src_init = core.Net('src_init') + src_values = Struct(('label', np.array(range(100)))) + src_blobs = NewRecord(src_init, src_values) + src_ds = Dataset(src_blobs) + FeedRecord(src_blobs, src_values, ws) + ws.run(src_init) + + """ 2. Read with limit smaller than size of dataset """ + dst_init = core.Net('dst_init') + dst_ds = Dataset(src_values.clone_schema()) + dst_ds.init_empty(dst_init) + ws.run(dst_init) + + with TaskGroup() as tg: + reader = ReaderWithLimit(src_ds.reader(), num_iter=10) + pipe(reader, dst_ds.writer(), num_threads=8) + session.run(tg) + self.assertFalse(ws.blobs[str(reader.data_finished())].fetch()) + self.assertEquals( + sorted(ws.blobs[str(dst_ds.content().label())].fetch()), range(10)) + + """ 3. Read with limit larger than size of dataset """ + ws.run(dst_init) + with TaskGroup() as tg: + reader = ReaderWithLimit(src_ds.reader(), num_iter=110) + pipe(reader, dst_ds.writer(), num_threads=8) + session.run(tg) + self.assertEquals( + sorted(ws.blobs[str(dst_ds.content().label())].fetch()), range(100)) + self.assertTrue(ws.blobs[str(reader.data_finished())].fetch()) diff --git a/caffe2/python/dataset.py b/caffe2/python/dataset.py index ebdbb0344bf5..84db02cde173 100644 --- a/caffe2/python/dataset.py +++ b/caffe2/python/dataset.py @@ -16,25 +16,33 @@ from __future__ import unicode_literals from caffe2.python import core, workspace from caffe2.python.dataio import Reader, Writer from caffe2.python.schema import ( - Struct, from_blob_list, Field, from_column_list) + Struct, from_blob_list, Field, from_column_list, InitEmptyRecord) import numpy as np class _DatasetReader(Reader): - def __init__(self, content, cursor, name, batch_size=1): + def __init__(self, dataset, name, batch_size=1): """Don't call this directly. Instead, use dataset.reader()""" - assert isinstance(content, Field) - Reader.__init__(self, content) - self._content = content - self.cursor = cursor - self.name = name + Reader.__init__(self, dataset.content()) + self.dataset = dataset + self.name = name or (dataset.name + '_cursor') self.batch_size = batch_size + self.cursor = None + + def setup_ex(self, init_net, exit_net): + if self.cursor is None: + self.cursor = init_net.CreateTreeCursor( + [], + [self.name], + fields=self.dataset.fields) def read(self, read_net): + assert self.cursor, 'setup not called.' + content = self.dataset.content() with core.NameScope(read_net.NextName(self.name)): fields = read_net.ReadNextBatch( - [self.cursor] + self._content.field_blobs(), - self._content.field_names(), + [self.cursor] + content.field_blobs(), + content.field_names(), batch_size=self.batch_size) if type(fields) is core.BlobReference: fields = [fields] @@ -45,37 +53,45 @@ class _DatasetReader(Reader): class _DatasetRandomReader(Reader): - def __init__(self, content, cursor, name, indices, batch_size=1): + def __init__(self, dataset, name, indices, batch_size=1): """Don't call this directly. Instead, use dataset.random_reader()""" - Reader.__init__(self, content) - self._content = content - self.cursor = cursor - self.name = name + Reader.__init__(self, dataset.content()) + self.dataset = dataset + self.cursor = None + self.name = name or (dataset.name + '_cursor') self.indices = indices self.batch_size = batch_size + def setup_ex(self, init_net, exit_net): + if self.cursor is None: + self.cursor = init_net.CreateTreeCursor( + [], + [self.name], + fields=self.dataset.fields) + def reset(self, net): net.ResetCursor([self.cursor], []) def computeoffset(self, net): self.reset(net) offsets = net.ComputeOffset( - [self.cursor] + self._content.field_blobs(), + [self.cursor] + self.dataset.content().field_blobs(), 'offsets') self.offsets = offsets def sort_and_shuffle(self, net, sort_by_field=None, shuffle_size=1, batch_size=1): # no sorting by default + content = self.dataset.content() sort_by_field_idx = -1 if sort_by_field: - assert sort_by_field in self._content.field_names(), ( + assert sort_by_field in content.field_names(), ( 'Must be valid field.') - sort_by_field_idx = self._content.field_names().index(sort_by_field) + sort_by_field_idx = content.field_names().index(sort_by_field) self.reset(net) indices = net.SortAndShuffle( - [self.cursor] + self._content.field_blobs(), + [self.cursor] + content.field_blobs(), 'indices', sort_by_field_idx=sort_by_field_idx, shuffle_size=shuffle_size, @@ -86,17 +102,21 @@ class _DatasetRandomReader(Reader): with core.NameScope(read_net.NextName(self.name)): fields = read_net.ReadRandomBatch( [self.cursor, self.indices, self.offsets] + ( - self._content.field_blobs()), - self._content.field_names(), + self.dataset.content().field_blobs()), + self.dataset.content().field_names(), batch_size=self.batch_size) return (read_net.IsEmpty([fields[0]]), fields) class _DatasetWriter(Writer): - def __init__(self, content, init_net): + def __init__(self, content): """Don't call this directly. Use dataset.writer() instead.""" self._content = content - self.mutex = init_net.CreateMutex([]) + self.mutex = None + + def setup_ex(self, init_net, exit_net): + if self.mutex is None: + self.mutex = init_net.CreateMutex([]) def write(self, writer_net, fields): """ @@ -108,6 +128,7 @@ class _DatasetWriter(Writer): writer_net: The net that will contain the Append operators. fields: A list of BlobReference to be appeneded to this dataset. """ + assert self.mutex is not None, 'setup not called.' field_blobs = self._content.field_blobs() assert len(fields) == len(field_blobs), ( 'Expected %s fields, got %s.' % (len(field_blobs), len(fields))) @@ -147,6 +168,7 @@ def execution_step_with_progress(name, init_net, substeps, rows_read): concurrent_substeps=True, report_interval=5) + class Dataset(object): """Represents an in-memory dataset with fixed schema. @@ -177,7 +199,7 @@ class Dataset(object): self.fields = fields.field_names() self.field_types = fields.field_types() self.name = name or 'dataset' - self.field_blobs = None + self.field_blobs = fields.field_blobs() if fields.has_blobs() else None def init_empty(self, init_net): """Initialize the blobs for this dataset with empty values. @@ -185,8 +207,8 @@ class Dataset(object): Empty arrays will be immediately fed into the current workspace, and `init_net` will take those blobs as external inputs. """ - self.field_blobs = [init_net.ConstantFill( - [], shape=[0], run_once=False) for f in self.fields] + self.field_blobs = InitEmptyRecord( + init_net, self.schema.clone_schema()).field_blobs() def init_from_dataframe(self, net, dataframe): """Initialize the blobs for this dataset from a Pandas dataframe. @@ -227,7 +249,7 @@ class Dataset(object): """ return self.field_types - def reader(self, init_net, cursor_name=None, batch_size=1): + def reader(self, init_net=None, cursor_name=None, batch_size=1): """Create a Reader object that is used to iterate through the dataset. This will append operations to `init_net` that create a TreeCursor, @@ -246,14 +268,12 @@ class Dataset(object): iterate through the dataset. """ assert self.field_blobs, 'Dataset not initialized.' - cursor_name = cursor_name or (self.name + '_cursor') - cursor = init_net.CreateTreeCursor( - [], - [cursor_name], - fields=self.fields) - return _DatasetReader(self.content(), cursor, cursor_name, batch_size) + reader = _DatasetReader(self, cursor_name, batch_size) + if init_net is not None: + reader.setup_ex(init_net, None) + return reader - def random_reader(self, init_net, indices=None, cursor_name=None, + def random_reader(self, init_net=None, indices=None, cursor_name=None, batch_size=1): """Create a Reader object that is used to iterate through the dataset. @@ -271,15 +291,12 @@ class Dataset(object): iterate through the dataset according to indices. """ assert self.field_blobs, 'Dataset not initialized.' - cursor_name = cursor_name or (self.name + '_cursor') - cursor = init_net.CreateTreeCursor( - [], - [cursor_name], - fields=self.fields) - return _DatasetRandomReader( - self.content(), cursor, cursor_name, indices, batch_size) + reader = _DatasetRandomReader(self, cursor_name, indices, batch_size) + if init_net is not None: + reader.setup_ex(init_net, None) + return reader - def writer(self, init_net): + def writer(self, init_net=None): """Create a Writer that can be used to append entries into the dataset. NOTE: Currently, it is not safe to append to a dataset @@ -292,4 +309,7 @@ class Dataset(object): (currently not used) """ assert self.field_blobs, 'Dataset not initialized.' - return _DatasetWriter(self.content(), init_net) + writer = _DatasetWriter(self.content()) + if init_net is not None: + writer.setup_ex(init_net, None) + return writer diff --git a/caffe2/python/dyndep.py b/caffe2/python/dyndep.py index 8d0af75614d8..18960143ef57 100644 --- a/caffe2/python/dyndep.py +++ b/caffe2/python/dyndep.py @@ -30,7 +30,19 @@ def InitOpsLibrary(name): # time when an actual call is made. print('Ignoring {} as it is not a valid file.'.format(name)) return + _init_impl(name) + + +_IMPORTED_DYNDEPS = set() + + +def GetImportedOpsLibraries(): + return _IMPORTED_DYNDEPS + + +def _init_impl(path): + _IMPORTED_DYNDEPS.add(path) with extension_loader.DlopenGuard(): - ctypes.CDLL(name) + ctypes.CDLL(path) # reinitialize available ops core.RefreshRegisteredOperators() diff --git a/caffe2/python/experiment_util.py b/caffe2/python/experiment_util.py index 333ad7d8612d..c4ce9a8a41f3 100644 --- a/caffe2/python/experiment_util.py +++ b/caffe2/python/experiment_util.py @@ -24,6 +24,8 @@ class ModelTrainerLog(): self.logstr("# %s" % str(runtime_args)) self.headers = None self.start_time = time.time() + self.last_time = self.start_time + self.last_input_count = 0 def logstr(self, str): with open(self.filename, "a") as f: @@ -33,11 +35,15 @@ class ModelTrainerLog(): def log(self, input_count, batch_count, additional_values): logdict = OrderedDict() + delta_t = time.time() - self.last_time + delta_count = input_count - self.last_input_count + self.last_time = time.time() + self.last_input_count = input_count logdict['time'] = time.time() - self.start_time logdict['input_counter'] = input_count logdict['batch_count'] = batch_count - if logdict['time'] > 0: - logdict['inputs_per_sec'] = input_count / logdict['time'] + if delta_t > 0: + logdict['inputs_per_sec'] = delta_count / delta_t else: logdict['inputs_per_sec'] = 0.0 diff --git a/caffe2/python/hsm_test.py b/caffe2/python/hsm_test.py index 12445324598f..f9cd060523ae 100644 --- a/caffe2/python/hsm_test.py +++ b/caffe2/python/hsm_test.py @@ -21,13 +21,25 @@ import caffe2.python.hsm_util as hsmu # 0,1,2 3,4 tree = hsm_pb2.TreeProto() words = [[0, 1, 2], [3, 4], [5, 6, 7, 8]] -node1 = hsmu.create_node_with_words(words[0]) -node2 = hsmu.create_node_with_words(words[1]) -node3 = hsmu.create_node_with_words(words[2]) -node4 = hsmu.create_node_with_nodes([node1, node2]) -node = hsmu.create_node_with_nodes([node4, node3]) +node1 = hsmu.create_node_with_words(words[0], "node1") +node2 = hsmu.create_node_with_words(words[1], "node2") +node3 = hsmu.create_node_with_words(words[2], "node3") +node4 = hsmu.create_node_with_nodes([node1, node2], "node4") +node = hsmu.create_node_with_nodes([node4, node3], "node5") tree.root_node.MergeFrom(node) +# structure: +# node5: [0, 2, ["node4", "node3"]] # offset, length, "node4, node3" +# node4: [2, 2, ["node1", "node2"]] +# node1: [4, 3, [0, 1 ,2]] +# node2: [7, 2, [3, 4] +# node3: [9, 4, [5, 6, 7, 8] +struct = [[0, 2, ["node4", "node3"], "node5"], + [2, 2, ["node1", "node2"], "node4"], + [4, 3, [0, 1, 2], "node1"], + [7, 2, [3, 4], "node2"], + [9, 4, [5, 6, 7, 8], "node3"]] + # Internal util to translate input tree to list of (word_id,path). serialized # hierarchy is passed into the operator_def as a string argument, hierarchy_proto = hsmu.create_hierarchy(tree) @@ -35,8 +47,82 @@ arg = caffe2_pb2.Argument() arg.name = "hierarchy" arg.s = hierarchy_proto.SerializeToString() +beam = 5 +args_search = [] +arg_search = caffe2_pb2.Argument() +arg_search.name = "tree" +arg_search.s = tree.SerializeToString() +args_search.append(arg_search) +arg_search = caffe2_pb2.Argument() +arg_search.name = "beam" +arg_search.f = beam +args_search.append(arg_search) + class TestHsm(hu.HypothesisTestCase): + def test_hsm_search(self): + samples = 10 + dim_in = 5 + X = np.random.rand(samples, dim_in).astype(np.float32) - 0.5 + w = np.random.rand(hierarchy_proto.size, dim_in) \ + .astype(np.float32) - 0.5 + b = np.random.rand(hierarchy_proto.size).astype(np.float32) - 0.5 + labels = np.array([np.random.randint(0, 8) for i in range(samples)]) \ + .astype(np.int32) + + workspace.GlobalInit(['caffe2']) + workspace.FeedBlob("data", X) + workspace.FeedBlob("weights", w) + workspace.FeedBlob("bias", b) + workspace.FeedBlob("labels", labels) + op = core.CreateOperator( + 'HSoftmaxSearch', + ['data', 'weights', 'bias'], + ['names', 'scores'], + 'HSoftmaxSearch', + arg=args_search) + workspace.RunOperatorOnce(op) + names = workspace.FetchBlob('names') + scores = workspace.FetchBlob('scores') + + def simulation_hsm_search(): + names = [] + scores = [] + for line in struct: + s, e = line[0], line[0] + line[1] + score = np.dot(X, w[s:e].transpose()) + b[s:e] + score = np.exp(score - np.max(score, axis=1, keepdims=True)) + score /= score.sum(axis=1, keepdims=True) + score = -np.log(score) + + score = score.transpose() + idx = -1 + for j, n in enumerate(names): + if n == line[3]: + idx = j + score += scores[j] + if idx == -1: + score[score > beam] = np.inf + else: + score[score - scores[idx] > beam] = np.inf + + for i, name in enumerate(line[2]): + scores.append(score[i]) + names.append(name) + scores = np.vstack(scores) + return names, scores.transpose() + + p_names, p_scores = simulation_hsm_search() + idx = np.argsort(p_scores, axis=1) + p_scores = np.sort(p_scores, axis=1) + p_names = np.array(p_names)[idx] + for i in range(names.shape[0]): + for j in range(names.shape[1]): + if names[i][j]: + assert(names[i][j] == p_names[i][j]) + self.assertAlmostEqual( + scores[i][j], p_scores[i][j], delta=0.001) + def test_hsm_run_once(self): workspace.GlobalInit(['caffe2']) workspace.FeedBlob("data", @@ -44,7 +130,7 @@ class TestHsm(hu.HypothesisTestCase): workspace.FeedBlob("weights", np.random.randn(1000, 100).astype(np.float32)) workspace.FeedBlob("bias", np.random.randn(1000).astype(np.float32)) - workspace.FeedBlob("labels", np.random.randn(1000).astype(np.int32)) + workspace.FeedBlob("labels", np.random.rand(1000).astype(np.int32) * 9) op = core.CreateOperator( 'HSoftmax', ['data', 'weights', 'bias', 'labels'], @@ -59,7 +145,7 @@ class TestHsm(hu.HypothesisTestCase): cpu_device_option = caffe2_pb2.DeviceOption() grad_checker = gradient_checker.GradientChecker( 0.01, 0.05, cpu_device_option, "default") - samples = 10 + samples = 9 dim_in = 5 X = np.zeros((samples, dim_in)).astype(np.float32) + 1 w = np.zeros((hierarchy_proto.size, dim_in)).astype(np.float32) + 1 diff --git a/caffe2/python/hsm_util.py b/caffe2/python/hsm_util.py index bcfedac8f0c4..f64eda2f9acb 100644 --- a/caffe2/python/hsm_util.py +++ b/caffe2/python/hsm_util.py @@ -12,15 +12,17 @@ from caffe2.proto import hsm_pb2 ''' -def create_node_with_words(words): +def create_node_with_words(words, name='node'): node = hsm_pb2.NodeProto() + node.name = name for word in words: node.word_ids.append(word) return node -def create_node_with_nodes(nodes): +def create_node_with_nodes(nodes, name='node'): node = hsm_pb2.NodeProto() + node.name = name for child_node in nodes: new_child_node = node.children.add() new_child_node.MergeFrom(child_node) @@ -41,6 +43,7 @@ def create_hierarchy(tree_proto): return path_proto def recursive_path_builder(node_proto, path, hierarchy_proto, max_index): + node_proto.offset = max_index path.append([max_index, len(node_proto.word_ids) + len(node_proto.children), 0]) max_index += len(node_proto.word_ids) + len(node_proto.children) diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py index 5cb2e7d5d193..b3e69aa7a515 100644 --- a/caffe2/python/hypothesis_test.py +++ b/caffe2/python/hypothesis_test.py @@ -150,6 +150,23 @@ class TestOperators(hu.HypothesisTestCase): self.assertDeviceChecks(dc, op, [X1, X2], [0]) self.assertGradientChecks(gc, op, [X1, X2], 0, [0]) + @given(inputs=hu.tensors(n=2), **hu.gcs) + def test_max(self, inputs, gc, dc): + op = core.CreateOperator("Max", ["X1", "X2"], ["Y"]) + + X1, X2 = inputs + # Make X1 and X2 far from each other, since X1=X2 is not differentiable + # and the step size of gradient checker is 0.05 + X1[np.logical_and(X1 >= X2 - 0.05, X1 <= X2)] -= 0.05 + X1[np.logical_and(X1 <= X2 + 0.05, X1 >= X2)] += 0.05 + self.assertDeviceChecks(dc, op, [X1, X2], [0]) + for i in range(2): + self.assertGradientChecks(gc, op, [X1, X2], i, [0]) + + def elementwise_max(X, Y): + return [np.maximum(X, Y)] + self.assertReferenceChecks(gc, op, [X1, X2], elementwise_max) + def test_add(self): def ref(x, y): return (x + y, ) @@ -227,6 +244,11 @@ class TestOperators(hu.HypothesisTestCase): self.assertDeviceChecks(dc, op, [X], [0]) self.assertReferenceChecks(gc, op, [X], softsign) + if inplace: + with self.assertRaises(Exception): + self.assertGradientChecks(gc, op, [X], 0, [0]) + else: + self.assertGradientChecks(gc, op, [X], 0, [0]) @given( device_options=st.lists( @@ -261,8 +283,9 @@ class TestOperators(hu.HypothesisTestCase): @given(axis=st.integers(min_value=1, max_value=4), num_output=st.integers(min_value=4, max_value=8), + engine=st.sampled_from(["", "PACKED"]), **hu.gcs) - def test_fully_connected_axis(self, axis, num_output, gc, dc): + def test_fully_connected_axis(self, axis, num_output, engine, gc, dc): np.random.seed(1) X = np.random.randn(1, 2, 3, 2, 1).astype(np.float32) @@ -281,6 +304,7 @@ class TestOperators(hu.HypothesisTestCase): "FC", ["X", "W", "b"], ["Y"], + engine=engine, axis=axis) for name, param in [("X", X), ("W", W), ("b", b)]: self.ws.create_blob(name).feed(param) @@ -354,16 +378,15 @@ class TestOperators(hu.HypothesisTestCase): axis=st.integers(0, 3), num_inputs=st.integers(2, 4), **hu.gcs) def test_depth_concat(self, ndim, axis, num_inputs, gc, dc): - if (axis >= ndim): - return + assume(axis < ndim) input_names = ['X0', 'X1', 'X2', 'X3'][:num_inputs] shape = [2, 3, 5, 7][:ndim] - individual_dims = [11, 13, 17, 19][:num_inputs] + individual_dims = [1, 2, 3, 4, 5][:num_inputs] inputs = [] for i in range(num_inputs): # Sets a unique dim and create the input. shape[axis] = individual_dims[i] - inputs.append(np.random.rand(*shape).astype(np.float32)) + inputs.append(np.random.randn(*shape).astype(np.float32)) op = core.CreateOperator("Concat", input_names, ["Y", "Y_dims"], axis=axis) self.assertDeviceChecks(dc, op, inputs, [0]) @@ -376,7 +399,7 @@ class TestOperators(hu.HypothesisTestCase): def test_depth_concat_with_order(self, num_inputs, order, gc, dc): input_names = ['X0', 'X1', 'X2', 'X3'][:num_inputs] shape = [2, 3, 5, 7] - individual_dims = [11, 13, 17, 19][:num_inputs] + individual_dims = [1, 2, 3, 4][:num_inputs] inputs = [] for i in range(num_inputs): # Sets a unique dim and create the input. diff --git a/caffe2/python/layer_model_helper.py b/caffe2/python/layer_model_helper.py new file mode 100644 index 000000000000..95eb915d5d92 --- /dev/null +++ b/caffe2/python/layer_model_helper.py @@ -0,0 +1,295 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, model_helper, schema +from caffe2.python.layers import layers + +from functools import partial + +import logging +import numpy as np +logger = logging.getLogger(__name__) + + +class LayerModelHelper(model_helper.ModelHelperBase): + """ + Model helper for building models on top of layers abstractions. + + Each layer is the abstraction that is higher level than Operator. Layer + is responsible for ownership of it's own parameters and can easily be + instantiated in multiple nets possible with different sets of ops. + As an example: one can easily instantiate predict and train nets from + the same set of layers, where predict net will have subset of the + operators from train net. + """ + + def __init__(self, name, input_feature_schema, trainer_extra_schema): + super(LayerModelHelper, self).__init__(name=name) + self._layer_names = set() + self._layers = [] + + # optimizer bookkeeping + self.param_to_optim = {} + + self._default_optimizer = None + self._loss = None + self._output_schema = None + + # Connect Schema to self.net. That particular instance of schmea will be + # use for generation of the Layers accross the network and would be used + # for connection with Readers. + self._input_feature_schema = schema.NewRecord( + self.net, + input_feature_schema + ) + self._trainer_extra_schema = schema.NewRecord( + self.net, + trainer_extra_schema + ) + + self._init_global_constants() + self.param_init_net = self.create_init_net('param_init_net') + + def add_global_constant(self, name, array, dtype=None): + # This is global namescope for constants. They will be created in all + # init_nets and there should be very few of them. + assert name not in self.global_constants + self.global_constants[name] = core.BlobReference( + self.net.NextName(name)) + + if dtype is None: + array = np.array(array) + else: + array = np.array(array, dtype=dtype) + + # TODO: make GivenTensor generic + op_name = None + if array.dtype == np.int32: + op_name = 'GivenTensorIntFill' + elif array.dtype == np.int64: + op_name = 'GivenTensorInt64Fill' + elif array.dtype == np.str: + op_name = 'GivenTensorStringFill' + else: + op_name = 'GivenTensorFill' + + self.global_constant_initializers.append( + core.CreateOperator(op_name, + [], + self.global_constants[name], + shape=array.shape, + values=array.flatten().tolist() + ) + ) + return self.global_constants[name] + + def _init_global_constants(self): + self.global_constants = {} + self.global_constant_initializers = [] + self.add_global_constant('ONE', 1.0) + self.add_global_constant('ZERO', 0.0) + self.add_global_constant('ZERO_RANGE', [0, 0], dtype='int32') + + def _add_global_constants(self, init_net): + for initializer_op in self.global_constant_initializers: + init_net._net.op.extend([initializer_op]) + + def create_init_net(self, name): + init_net = core.Net(name) + self._add_global_constants(init_net) + return init_net + + def next_block_name(self, prefix): + return prefix + "_{}".format( + len(filter(lambda x: x.startswith(prefix), self._layer_names))) + + def add_layer(self, layer): + self._layers.append(layer) + for param in layer.get_parameters(): + self.param_to_optim[str(param.parameter)] = param.optimizer + + # The primary value of adding everything to self.net - generation of the + # operators right away, i.e. if error happens it'll be detected + # immediately. Other then this - create_x_net should be called. + layer.add_operators(self.net, self.param_init_net) + return layer.get_output_schema() + + @property + def default_optimizer(self): + return self._default_optimizer + + @default_optimizer.setter + def default_optimizer(self, optimizer): + self._default_optimizer = optimizer + + @property + def input_feature_schema(self): + return self._input_feature_schema + + @property + def trainer_extra_schema(self): + return self._trainer_extra_schema + + @property + def output_schema(self): + assert self._output_schema is not None + return self._output_schema + + @output_schema.setter + def output_schema(self, schema): + assert self._output_schema is None + self._output_schema = schema + + @property + def loss(self): + assert self._loss is not None + return self._loss + + @loss.setter + def loss(self, loss): + assert self._loss is None + self._loss = loss + + def __getattr__(self, layer): + if not layers.layer_exists(layer): + raise ValueError( + "Tring to create non-registered layer: {0}".format(layer)) + + def wrapper(*args, **kwargs): + return self.add_layer( + layers.create_layer(layer, self, *args, **kwargs)) + return wrapper + + @property + def layers(self): + return self._layers + + # TODO(amalevich): Optimizer should not really in model. Move it out. + # Copy over from another Helper + def SgdOptim(self, base_lr=0.01, policy='fixed', **kwargs): + return partial(self.Sgd, base_lr=base_lr, policy=policy, **kwargs) + + def AdagradOptim(self, alpha=0.01, epsilon=1e-4, **kwargs): + return partial(self.Adagrad, alpha=alpha, epsilon=epsilon, **kwargs) + + def FtrlOptim(self, alpha=0.01, beta=1e-4, lambda1=0, lambda2=0, **kwargs): + return partial(self.Ftrl, alpha=alpha, beta=beta, lambda1=lambda1, + lambda2=lambda2, **kwargs) + + def _GetOne(self): + return self.global_constants['ONE'] + + def Adagrad(self, net, param_init_net, + param, grad, alpha, epsilon, dedup_indices=False, + engine=''): + if alpha <= 0: + return + + param_square_sum = param_init_net.ConstantFill( + [param], + core.ScopedBlobReference(param + "_square_sum"), + value=0.0 + ) + # Set learning rate to negative so that we can add the grad to param + # directly later. + lr = param_init_net.ConstantFill( + [], core.ScopedBlobReference(param + "_lr"), value=-alpha) + if isinstance(grad, core.GradientSlice): + if dedup_indices: + grad = net.DeduplicateGradientSlices(grad) + + net.SparseAdagrad( + [param, param_square_sum, grad.indices, grad.values, lr], + [param, param_square_sum], + epsilon=epsilon, + engine=engine + ) + + else: + net.Adagrad( + [param, param_square_sum, grad, lr], + [param, param_square_sum], + epsilon=epsilon, + engine=engine + ) + + def Ftrl(self, net, param_init_net, + param, grad, alpha, beta, lambda1, lambda2, + dedup_indices=False, engine=''): + if alpha <= 0: + return + + nz = param_init_net.ConstantFill( + [param], + core.ScopedBlobReference(param + "_ftrl_nz"), + extra_shape=[2], + value=0.0 + ) + if isinstance(grad, core.GradientSlice): + if dedup_indices: + grad = net.DeduplicateGradientSlices(grad) + + net.SparseFtrl( + [param, nz, grad.indices, grad.values], + [param, nz], + engine=engine, + alpha=alpha, + beta=beta, + lambda1=lambda1, + lambda2=lambda2 + ) + else: + net.Ftrl( + [param, nz, grad], + [param, nz], + engine=engine, + alpha=alpha, + beta=beta, + lambda1=lambda1, + lambda2=lambda2 + ) + + def Sgd(self, net, param_init_net, + param, grad, base_lr, policy, momentum=0.0, **kwargs): + if (base_lr <= 0): + return + # Set learning rate to negative so that we can add the grad to param + # directly later. + + # TODO(amalevich): Get rid of iter duplication if other parts are good + # enough + lr = net.LearningRate( + [net.Iter([], 1)], + core.ScopedBlobReference(param + "_lr"), + base_lr=-base_lr, + policy=policy, + **kwargs + ) + + if momentum > 0: + momentum_data = param_init_net.ConstantFill( + param, core.ScopedBlobReference(param + "_momentum"), value=0.) + + if isinstance(grad, core.GradientSlice): + assert momentum == 0., "Doesn't support momentum for sparse" + net.ScatterWeightedSum( + [param, self._GetOne(), + grad.indices, grad.values, lr], + param + ) + else: + if momentum > 0.: + net.MomentumSGD( + [grad, momentum_data, lr], [grad, momentum_data], + momentum=momentum, + nesterov=1) + coeff = self._GetOne() + else: + coeff = lr + + net.WeightedSum( + [param, self._GetOne(), grad, coeff], + param + ) diff --git a/caffe2/python/layer_model_instantiator.py b/caffe2/python/layer_model_instantiator.py new file mode 100644 index 000000000000..a6959f3cfef8 --- /dev/null +++ b/caffe2/python/layer_model_instantiator.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, schema +from caffe2.python.layers.layers import InstantiationContext +from caffe2.python.layers.tags import Tags + +import itertools + + +def generate_predict_net(model): + predict_net = core.Net('predict_net') + + for layer in model.layers: + if Tags.TRAIN_ONLY not in layer.tags: + layer.add_operators( + predict_net, context=InstantiationContext.PREDICTION) + return predict_net + + +def generate_training_nets(model): + train_net = core.Net('train_net') + train_init_net = model.create_init_net('train_init_net') + + loss = model.loss + for layer in model.layers: + layer.add_operators(train_net, train_init_net) + grad_map = train_net.AddGradientOperators(loss.field_blobs()) + for param, optimizer in model.param_to_optim.items(): + if not optimizer: + optimizer = model.default_optimizer + optimizer(train_net, train_init_net, param, grad_map[str(param)]) + + trainer_schema = schema.Struct( + *itertools.chain( + model.trainer_extra_schema.get_children(), + model.input_feature_schema.get_children(), + ) + ) + + train_net.set_input_record(trainer_schema) + return train_init_net, train_net diff --git a/caffe2/python/layers/__init__.py b/caffe2/python/layers/__init__.py new file mode 100644 index 000000000000..b1ca067e5fae --- /dev/null +++ b/caffe2/python/layers/__init__.py @@ -0,0 +1,27 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from importlib import import_module +import pkgutil +import sys +from . import layers + + +def import_recursive(package): + """ + Takes a package and imports all modules underneath it + """ + pkg_dir = package.__path__ + module_location = package.__name__ + for (module_loader, name, ispkg) in pkgutil.iter_modules(pkg_dir): + module_name = "{}.{}".format(module_location, name) # Module/package + module = import_module(module_name) + if ispkg: + import_recursive(module) + +import_recursive(sys.modules[__name__]) + +for cls in layers.ModelLayer.__subclasses__(): + layers.register_layer(cls.__name__, cls) diff --git a/caffe2/python/layers/batch_lr_loss.py b/caffe2/python/layers/batch_lr_loss.py new file mode 100644 index 000000000000..bc6ccb4a2e35 --- /dev/null +++ b/caffe2/python/layers/batch_lr_loss.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, schema +from caffe2.python.layers.layers import ( + ModelLayer, +) +from caffe2.python.layers.tags import ( + Tags +) +import numpy as np + + +class BatchLRLoss(ModelLayer): + + def __init__(self, model, input_record, name='batch_lr_loss', **kwargs): + super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs) + + schema.is_schema_subset( + schema.Struct( + ('label', schema.Scalar()), + ('prediction', schema.Scalar()) + ), + input_record + ) + self.tags.update({Tags.TRAIN_ONLY}) + + self.output_schema = schema.Scalar( + np.float32, + core.BlobReference(model.net.NextName(self.name + '_output'))) + + # This should be a bit more complicated than it is right now + def add_ops(self, net): + class_probabilities = net.MakeTwoClass( + self.input_record.prediction.field_blobs()) + label = self.input_record.label.field_blobs() + if self.input_record.label.field_types()[0] != np.int32: + label = [net.Cast(label, to='int32')] + + xent = net.LabelCrossEntropy( + [class_probabilities] + label) + net.AveragedLoss(xent, self.output_schema.field_blobs()) diff --git a/caffe2/python/layers/concat.py b/caffe2/python/layers/concat.py new file mode 100644 index 000000000000..291e63d669f7 --- /dev/null +++ b/caffe2/python/layers/concat.py @@ -0,0 +1,56 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, schema +from caffe2.python.layers.layers import ( + ModelLayer, +) +import numpy as np + + +class Concat(ModelLayer): + + def __init__(self, model, input_record, axis=1, + name='concat', **kwargs): + super(Concat, self).__init__(model, name, input_record, **kwargs) + self.axis = axis + assert isinstance(input_record, schema.Struct),\ + "Incorrect input type. Excpected Struct, but received: {0}".\ + format(input_record) + + shapes = [] + for field_name, field_type in input_record.fields.items(): + assert isinstance(field_type, schema.Scalar),\ + "Incorrect input type. Excpected Scalar, but received: {0}".\ + format(field_type) + # Assume that first dimension is batch, so actual axis in shape is + # axis - 1 + assert len(field_type.field_type().shape) >= axis,\ + "Concat expects that limited dimensions of the input tensor" + shapes.append(list(field_type.field_type().shape)) + + concat_dim = 0 + for shape in shapes: + concat_dim += shape[axis - 1] + shape[axis - 1] = 0 + assert shape == shapes[0],\ + "Shapes {0} and {1} are not compatible for Concat".\ + format(shape, shapes[0]) + output_dims = shapes[0] + output_dims[axis - 1] = concat_dim + + self.output_schema = schema.Scalar( + (np.float32, output_dims), + core.BlobReference(model.net.NextName(self.name + '_output'))) + + def add_ops(self, net): + net.Concat( + self.input_record.field_blobs(), + [ + self.output_schema.field_blobs()[0], + net.NextName(str("_" + self.output_schema.field_blobs()[0] + + "_concat_dims"))], + axis=self.axis, + ) diff --git a/caffe2/python/layers/fc.py b/caffe2/python/layers/fc.py new file mode 100644 index 000000000000..dee1065e126e --- /dev/null +++ b/caffe2/python/layers/fc.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, schema +from caffe2.python.layers.layers import ( + ModelLayer, + LayerParameter +) +import math +import numpy as np + + +class FC(ModelLayer): + + def __init__(self, model, input_record, output_dims, weight_init=None, + bias_init=None, weight_optim=None, bias_optim=None, name='fc', + **kwargs): + super(FC, self).__init__(model, name, input_record, **kwargs) + assert isinstance(input_record, schema.Scalar), "Incorrect input type" + assert len(input_record.field_types()[0].shape) > 0,\ + "FC expects limited dimensions of the input tensor" + + input_dims = input_record.field_types()[0].shape[0] + + self.output_schema = schema.Scalar( + (np.float32, output_dims), + core.BlobReference(model.net.NextName(self.name + '_output')) + ) + + scale = math.sqrt(1.0 / input_dims) + weight_init = weight_init if weight_init else ( + 'UniformFill', {'min': -scale, 'max': scale}) + bias_init = bias_init if bias_init else ( + 'UniformFill', {'min': -scale, 'max': scale}) + + self.w = model.net.NextName(self.name + "_w") + self.b = model.net.NextName(self.name + "_b") + + self.params.append( + LayerParameter( + parameter=self.w, + initializer=core.CreateOperator(weight_init[0], + [], + self.w, + shape=[output_dims, input_dims], + **weight_init[1] + ), + optimizer=weight_optim)) + self.params.append( + LayerParameter( + parameter=self.b, + initializer=core.CreateOperator(bias_init[0], + [], + self.b, + shape=[output_dims, ], + **bias_init[1] + ), + optimizer=bias_optim)) + + def add_ops(self, net): + net.FC(self.input_record.field_blobs() + [self.w, self.b], + self.output_schema.field_blobs(), **self.kwargs) diff --git a/caffe2/python/layers/layers.py b/caffe2/python/layers/layers.py new file mode 100644 index 000000000000..c5e5f1de9d7b --- /dev/null +++ b/caffe2/python/layers/layers.py @@ -0,0 +1,87 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import schema +from caffe2.python.layers.tags import TagContext + +from collections import namedtuple +import numpy as np + +# Some types to simplify descriptions of things traveling between ops +IdList = schema.List(np.int64) +IdScoreList = schema.Map(np.int64, np.float32) + + +class InstantiationContext(object): + """ + List of contexts where layer could be instantitated + """ + TRAINING = 'training' + PREDICTION = 'prediction' + + +_LAYER_REGISTRY = {} + + +def register_layer(name, layer): + assert name not in _LAYER_REGISTRY, "{0} already exists".format(name) + _LAYER_REGISTRY[name] = layer + + +def layer_exists(name): + return name in _LAYER_REGISTRY + + +def create_layer(name, *args, **kwargs): + return _LAYER_REGISTRY[name](*args, **kwargs) + +# TODO(amalevich): Modify this to some better struct, something closer to +# ParameterInfo. +LayerParameter = namedtuple( + 'LayerParameter', ['parameter', 'optimizer', 'initializer']) + + +class ModelLayer(object): + + def __init__(self, model, prefix, input_record, tags=set(), **kwargs): + self.name = model.next_block_name(prefix) + self.model = model + self.kwargs = kwargs + self.input_record = input_record + self.output_schema = None + self.tags = set(tags) + self.tags.update(TagContext.current().tags) + self.params = [] + + def get_output_schema(self): + assert self.output_schema is not None, "Schema is not initialized" + return self.output_schema + + def get_parameters(self): + return self.params + + def add_operators(self, net, init_net=None, + context=InstantiationContext.TRAINING): + if context != InstantiationContext.PREDICTION: + assert init_net,\ + "Only prediction context can be used without init_net" + if init_net: + for param in self.params: + # TODO(amalevich): Either return back to lambdas, that add all + # params (looks a bit safer and breaking less abstractions) or + # extend Net interface to this type of operations better + init_net._net.op.extend([param.initializer]) + if context == InstantiationContext.TRAINING: + self.add_train_ops(net) + else: + self.add_ops(net) + + def add_ops(self, net): + raise NotImplementedError + + def add_train_ops(self, net): + # Default train layer implementation is completely matching predict + # layer implementation. + self.add_ops(net) diff --git a/caffe2/python/layers/simple_operator_layers.py b/caffe2/python/layers/simple_operator_layers.py new file mode 100644 index 000000000000..602f315116fa --- /dev/null +++ b/caffe2/python/layers/simple_operator_layers.py @@ -0,0 +1,67 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import schema +from caffe2.python.layers.layers import ( + ModelLayer, +) + + +def simple_init(self, model, input_record, *args, **kwargs): + ModelLayer.__init__(self, model, self.operator, input_record, **kwargs) + assert self.operator is not None, "Try to create invalid operator layer" + self.args = args + self.output_schema = schema.NewRecord(self.model.net, input_record) + + +def first_field_schema_init(self, model, input_record, *args, **kwargs): + ModelLayer.__init__(self, model, self.operator, input_record, **kwargs) + assert self.operator is not None, "Try to create invalid operator layer" + assert isinstance(input_record, schema.Struct),\ + "Operator {0} expects schema.Struct as input, received {1} instead".\ + format(self.operator, input_record) + self.args = args + self.output_schema = schema.NewRecord(self.model.net, input_record[0]) + + +def simple_add_ops(self, net): + getattr( + net, + self.operator)( + self.input_record.field_blobs(), + self.output_schema.field_blobs(), + *self.args, + **self.kwargs + ) + +_simple_operators = ['Softmax', 'Relu', 'Sigmoid', 'Tanh'] +_first_field_schema_operators = ['Add'] + +for operator in _simple_operators: + # Generate class instance with name 'operator', that is doing going to use + # simple_init and simple_add_ops implementations for __init__ and add_ops + # calls. It'll also get automatically registered in the registry. + type( + str(operator), + (ModelLayer,), + {'__init__': simple_init, + 'add_ops': simple_add_ops, + 'operator': operator + } + ) + +for operator in _first_field_schema_operators: + # Generate class instance with name 'operator', that is doing going to use + # first_field_schema_init and simple_add_ops implementations for __init__ + # and add_ops calls. It'll also get automatically registered in the + # registry. + type( + str(operator), + (ModelLayer,), + {'__init__': first_field_schema_init, + 'add_ops': simple_add_ops, + 'operator': operator + } + ) diff --git a/caffe2/python/layers/sparse_lookup.py b/caffe2/python/layers/sparse_lookup.py new file mode 100644 index 000000000000..64c03e6e605a --- /dev/null +++ b/caffe2/python/layers/sparse_lookup.py @@ -0,0 +1,96 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, schema +from caffe2.python.layers.layers import ( + IdList, + IdScoreList, + LayerParameter, + ModelLayer, +) +import math +import numpy as np + + +class SparseLookup(ModelLayer): + _supported_reducers = ['LogMeanExp', 'LogSumExp', 'Max', 'Mean', 'Sum'] + + def __init__(self, model, input_record, inner_shape, reducer, + weight_init=None, weight_optim=None, + name='sparse_lookup', **kwargs): + super(SparseLookup, self).__init__(model, name, input_record, **kwargs) + + if isinstance(inner_shape, int): + inner_shape = [inner_shape] + assert isinstance(inner_shape, list) or isinstance(inner_shape, tuple),\ + "Unexpected type for inner_shape, expected list or tuple, got {0}".\ + format(type(inner_shape)) + + # TODO Add some asserts about input type + assert reducer in self._supported_reducers, "Unsupported reducer: {}".\ + format(reducer) + self.reducer = reducer + + assert input_record.items.metadata is not None,\ + "Features without metadata are not supported" + input_dim = input_record.items.metadata.categorical_limit + assert input_dim is not None, "Unbounded features are not supported" + + self.output_schema = schema.Scalar( + (np.float32, inner_shape), + core.BlobReference(model.net.NextName(self.name + '_output'))) + + scale = math.sqrt(1.0 / input_dim) + self.shape = [input_dim] + inner_shape + self.weight_init = weight_init if weight_init else ( + 'UniformFill', {'min': -scale, 'max': scale}) + + self.w = model.net.NextName(self.name + "_w") + self.params.append( + LayerParameter( + parameter=self.w, + initializer=core.CreateOperator(self.weight_init[0], + [], + self.w, + shape=self.shape, + **self.weight_init[1] + ), + optimizer=weight_optim + )) + + def add_ops(self, net): + if schema.equal_schemas(self.input_record, IdList): + if self.reducer == 'Sum': + net.SparseLengthsSum( + [ + self.w, + self.input_record.items(), + self.input_record.lengths() + ], + self.output_schema.field_blobs() + ) + else: + table_rows = net.Gather([self.w, self.input_record.keys()]) + segments = net.LengthsToRanges(self.input_record.lengths()) + net.__getattr__('SortedSegmentRange' + self.reducer)( + [table_rows, segments], + self.output_schema.field_blobs() + ) + elif schema.equal_schemas(self.input_record, IdScoreList): + if self.reducer == 'Sum': + net.SparseLengthsWeightedSum( + [ + self.w, + self.input_record.values(), + self.input_record.keys(), + self.input_record.lengths() + ], + self.output_schema.field_blobs() + ) + else: + raise "Only Sum is supported for IdScoreList input." +\ + "Trying to create with {}".format(self.reducer) + else: + raise "Unsupported input type {0}".format(self.input_record) diff --git a/caffe2/python/layers/sparse_to_dense.py b/caffe2/python/layers/sparse_to_dense.py new file mode 100644 index 000000000000..f50889a4f440 --- /dev/null +++ b/caffe2/python/layers/sparse_to_dense.py @@ -0,0 +1,131 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, schema +from caffe2.python.layers.layers import ( + ModelLayer, +) +import numpy as np + + +class SparseToDense(ModelLayer): + _known_types = ['FLOAT', 'ID_LIST'] + + def __init__(self, model, input_record, input_specs, + name='sparse_to_dense', **kwargs): + """ + `input_specs` follows the format of FeatureSpec from schema. To be more + precise it's a namedtuple that should have: + 'feature_type', 'feature_names', 'feature_ids' + """ + super(SparseToDense, self).__init__(model, name, + input_record, **kwargs) + + self.input_specs = input_specs + + outputs = [] + for field, feature_specs in self.input_specs: + assert len(feature_specs.feature_names) ==\ + len(feature_specs.feature_ids) + if feature_specs.feature_type == 'FLOAT': + outputs.append(( + field, + schema.Scalar( + (np.float32, len(feature_specs.feature_ids)), + core.BlobReference( + model.net.NextName(self.name + field + '_output')) + ) + )) + elif feature_specs.feature_type == 'ID_LIST': + outputs.append(( + field, + schema.Struct( + ('ranges', + schema.Scalar( + ( + np.int32, + (len(feature_specs.feature_ids), 2) + ), + core.BlobReference( + model.net.NextName( + self.name + field + '_ranges') + ) + ), + ), + ('values', input_record[field].values.items), + ) + )) + else: + raise TypeError( + "Unsupported input type: {0}". + format(feature_specs.feature_type)) + + # TODO(amalevich): This schema is producing ranges. And thus if there is + # something using it it should support ranges as well. It might be + # confusing, if we don't add better support for ranges/have it as a + # first layer + self.output_schema = schema.Struct( + *outputs + ) + + # TODO(amalevich): Consider moving this data to schema, instead + # Structs doens't support attaching metadata to them and clonning + # will break things badly, but this is the most elegant way to pass + # this info around. Should we change it or it'll be too much work and + # not worse it? + """ + for field, feature_specs in input_specs: + self.output_schema[field].set_metadata( + schema.Metadata( + categorical_limit=None, + expected_value=None, + feature_specs=feature_specs + ) + ) + """ + self.zero = model.global_constants['ZERO'] + self.zero_range = model.global_constants['ZERO_RANGE'] + + # Add operators to all types that need to be densified + def add_ops(self, net): + record = self.input_record + for field, feature_specs in self.input_specs: + if feature_specs.feature_type == 'FLOAT': + net.SparseToDenseMask( + [ + record[field].keys(), + record[field].values(), + self.zero, + record[field].lengths(), + ], + [ + self.output_schema[field](), + ], + mask=feature_specs.feature_ids, + ) + elif feature_specs.feature_type == 'ID_LIST': + id_list_ranges = net.LengthsToRanges( + record[field].values.lengths(), 1 + ) + net.SparseToDenseMask( + [ + record[field].keys(), id_list_ranges, self.zero_range, + record[field].lengths() + ], + self.output_schema[field].ranges(), + mask=feature_specs.feature_ids, + ) + + def get_metadata(self): + metadata = [] + for field, feature_specs in self.input_specs: + metadata.append( + ( + feature_specs, + self.output_schema[field].field_blobs(), + self.output_schema[field].field_types() + ) + ) + return metadata diff --git a/caffe2/python/layers/tags.py b/caffe2/python/layers/tags.py new file mode 100644 index 000000000000..da1a06febdfb --- /dev/null +++ b/caffe2/python/layers/tags.py @@ -0,0 +1,50 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import context + + +@context.define_context(allow_default=True) +class TagContext(object): + """ + Scope driven way to provide tags to the layers. + """ + + def __init__(self, tags=None): + # Tags is expected to be list to keep order of adding/removing things + self.tags = tags or [] + + def add_tags(self, tags): + self.tags.extend(tags) + + def remove_tags(self, tags): + assert self.tags[-len(tags):] == tags + self.tags = self.tags[:-len(tags)] + + +class Tags(object): + # TODO(amalevich): Tags might need to live in their own contexts, add this + # split later + TRAIN_ONLY = 'train_only' + PREPROCESSING = 'preprocessing' + + # In certain cases we want to have different schema for training and + # prediction, as an example in prediction we might need to have only + # subset of ids present in the orignal schema. This tag is one of the ways + # to mark operators that will be removed from prediction and should + # override schema for predictors. + PREDICTION_SCHEMA = 'prediction_schema' + + def __init__(self, tags): + if not isinstance(tags, list): + tags = [tags] + self.tags = tags + + def __enter__(self): + TagContext.current().add_tags(self.tags) + return self + + def __exit__(self, type, value, traceback): + TagContext.current().remove_tags(self.tags) diff --git a/caffe2/python/load_save_test.py b/caffe2/python/load_save_test.py index 6a324f1aae4b..572b99a6f9b2 100644 --- a/caffe2/python/load_save_test.py +++ b/caffe2/python/load_save_test.py @@ -26,6 +26,7 @@ else: # Inherit from this test instead. If you add a test here, # each derived class will inherit it as well and cause test duplication class TestLoadSaveBase(test_util.TestCase): + def __init__(self, methodName, db_type='minidb'): super(TestLoadSaveBase, self).__init__(methodName) self._db_type = db_type @@ -35,7 +36,7 @@ class TestLoadSaveBase(test_util.TestCase): dst_device_type=st.sampled_from(DEVICES), dst_gpu_id=st.integers(min_value=0, max_value=max_gpuid)) def load_save(self, src_device_type, src_gpu_id, - dst_device_type, dst_gpu_id): + dst_device_type, dst_gpu_id): workspace.ResetWorkspace() dtypes = [np.float16, np.float32, np.float64, np.bool, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16] @@ -65,15 +66,16 @@ class TestLoadSaveBase(test_util.TestCase): workspace.ResetWorkspace() self.assertEqual(len(workspace.Blobs()), 0) - def _LoadTest(keep_device, device_type, gpu_id): + def _LoadTest(keep_device, device_type, gpu_id, blobs, loadAll): """A helper subfunction to test keep and not keep.""" op = core.CreateOperator( "Load", - [], [str(i) for i in range(len(arrays))], + [], blobs, absolute_path=1, db=os.path.join(tmp_folder, "db"), db_type=self._db_type, device_option=dst_device_option, - keep_device=keep_device) + keep_device=keep_device, + load_all=loadAll) self.assertTrue(workspace.RunOperatorOnce(op)) for i, arr in enumerate(arrays): self.assertTrue(workspace.HasBlob(str(i))) @@ -90,17 +92,28 @@ class TestLoadSaveBase(test_util.TestCase): self.assertEqual(proto.tensor.device_detail.cuda_gpu_id, gpu_id) + blobs = [str(i) for i in range(len(arrays))] # Load using device option stored in the proto, i.e. # src_device_option - _LoadTest(1, src_device_type, src_gpu_id) + _LoadTest(1, src_device_type, src_gpu_id, blobs, 0) # Load again, but this time load into dst_device_option. - _LoadTest(0, dst_device_type, dst_gpu_id) + _LoadTest(0, dst_device_type, dst_gpu_id, blobs, 0) # Load back to the src_device_option to see if both paths are able # to reallocate memory. - _LoadTest(1, src_device_type, src_gpu_id) + _LoadTest(1, src_device_type, src_gpu_id, blobs, 0) # Reset the workspace, and load directly into the dst_device_option. workspace.ResetWorkspace() - _LoadTest(0, dst_device_type, dst_gpu_id) + _LoadTest(0, dst_device_type, dst_gpu_id, blobs, 0) + + # Test load all which loads all blobs in the db into the workspace. + workspace.ResetWorkspace() + _LoadTest(1, src_device_type, src_gpu_id, [], 1) + # Load again making sure that overwrite functionality works. + _LoadTest(1, src_device_type, src_gpu_id, [], 1) + # Load again with different device. + _LoadTest(0, dst_device_type, dst_gpu_id, [], 1) + workspace.ResetWorkspace() + _LoadTest(0, dst_device_type, dst_gpu_id, [], 1) finally: # clean up temp folder. try: @@ -111,9 +124,30 @@ class TestLoadSaveBase(test_util.TestCase): class TestLoadSave(TestLoadSaveBase): + def testLoadSave(self): self.load_save() + def testRepeatedArgs(self): + dtypes = [np.float16, np.float32, np.float64, np.bool, np.int8, + np.int16, np.int32, np.int64, np.uint8, np.uint16] + arrays = [np.random.permutation(6).reshape(2, 3).astype(T) + for T in dtypes] + + for i, arr in enumerate(arrays): + self.assertTrue(workspace.FeedBlob(str(i), arr)) + self.assertTrue(workspace.HasBlob(str(i))) + + # Saves the blobs to a local db. + tmp_folder = tempfile.mkdtemp() + op = core.CreateOperator( + "Save", + [str(i) for i in range(len(arrays))] * 2, [], + absolute_path=1, + db=os.path.join(tmp_folder, "db"), db_type=self._db_type) + with self.assertRaises(RuntimeError): + self.assertRaises(workspace.RunOperatorOnce(op)) + if __name__ == '__main__': unittest.main() diff --git a/caffe2/python/model_helper.py b/caffe2/python/model_helper.py index 26b8b97f3d27..59e6f2be5669 100644 --- a/caffe2/python/model_helper.py +++ b/caffe2/python/model_helper.py @@ -2,11 +2,52 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals -from caffe2.python import core + +from caffe2.python import core, scope +import numpy as np import logging +class ParameterType(object): + DENSE = 'dense' + SPARSE = 'sparse' + + +class ParameterInfo(object): + def __init__( + self, param_id, param, key=None, shape=None, length=None): + assert isinstance(param, core.BlobReference) + self.param_id = param_id + self.name = str(param) + self.blob = param + self.key = key + self.shape = shape + self.size = None if shape is None else np.prod(shape) + self.length = max(1, length if length is not None else 1) + self.grad = None + self._cloned_init_net = None + + def grad_type(self): + assert self.grad is not None, ( + 'Gradient not defined for parameter %s' % self.name) + return ( + ParameterType.SPARSE if isinstance(self.grad, core.GradientSlice) + else ParameterType.DENSE) + + def cloned_init_net(self): + if not self._cloned_init_net: + init_net, outputs = self.blob.Net().ClonePartial( + 'param_%d_%s_init' % (self.param_id, self.name), + inputs=[], + outputs=[self.blob]) + self._cloned_init_net = (init_net, outputs[0]) + return self._cloned_init_net + + def __str__(self): + return self.name + + class ModelHelperBase(object): """A helper model so we can write models more easily, without having to manually define parameter initializations and operators separately. @@ -23,10 +64,79 @@ class ModelHelperBase(object): self.param_to_grad = {} self.params = [] + self._param_info = [] + self._devices = [] self.gradient_ops_added = False self.init_params = init_params self.allow_not_known_ops = allow_not_known_ops + def _infer_param_shape(self, param): + for op in self.param_init_net.Proto().op: + if str(param) in op.output: + for arg in op.arg: + if arg.name == "shape": + return list(arg.ints) + return None + + def _update_param_info(self): + assert len(self._param_info) <= len(self.params) + for param in self.params[len(self._param_info):]: + if not isinstance(param, core.BlobReference): + param = core.BlobReference(str(param), net=self._param_init_net) + self._param_info.append(ParameterInfo( + param_id=len(self._param_info), + param=param, + shape=self._infer_param_shape(param))) + for info in self._param_info: + info.grad = self.param_to_grad.get(info.name) + + def add_param(self, param, key=None, shape=None, length=None): + self._update_param_info() + if key is not None: + idx = self.net.input_record().field_blobs().index(key) + key = self.net.input_record().field_names()[idx] + shape = shape if shape is not None else self._infer_param_shape(param) + self.params.append(param) + if not isinstance(param, core.BlobReference): + param = core.BlobReference(str(param), net=self._param_init_net) + self._param_info.append(ParameterInfo( + param_id=len(self._param_info), + param=param, + shape=shape, + key=key, + length=length, + )) + return self._param_info[-1] + + def param_info(self, grad_type=None, id=None): + self._update_param_info() + if id is not None: + assert grad_type is None + info = self._param_info[id] + assert info.param_id == id + return info + elif grad_type is not None: + return [ + info for info in self._param_info + if info.grad_type() == grad_type] + else: + return self._param_info + + def GetParams(self, namescope=None): + ''' + Returns the params in current namescope + ''' + if namescope is None: + namescope = scope.CurrentNameScope() + else: + if not namescope.endswith(scope._NAMESCOPE_SEPARATOR): + namescope += scope._NAMESCOPE_SEPARATOR + + if namescope == '': + return self.params[:] + else: + return [p for p in self.params if p.GetNameScope() == namescope] + def Proto(self): return self.net.Proto() @@ -46,10 +156,14 @@ class ModelHelperBase(object): if self.gradient_ops_added: raise RuntimeError("You cannot run AddGradientOperators twice.") self.gradient_ops_added = True + + # We need to use empty namescope when creating the gradients + # to prevent duplicating the namescope prefix for gradient blobs. grad_map = self.net.AddGradientOperators(*args, **kwargs) for p in self.params: if str(p) in grad_map: self.param_to_grad[p] = grad_map[str(p)] + return grad_map def TensorProtosDBInput( @@ -89,8 +203,16 @@ class ModelHelperBase(object): self.params.extend(parameters) return self.net.__getattr__(op_type)(inputs, *args, **kwargs) + def GetDevices(self): + assert len(self._devices) > 0, \ + "Use data_parallel_model to run model on multiple GPUs." + return self._devices + def __getattr__(self, op_type): """Catch-all for all other operators, mostly those without params.""" + if op_type.startswith('__'): + raise AttributeError(op_type) + if not core.IsOperator(op_type): raise RuntimeError( 'Method ' + op_type + ' is not a registered operator.' @@ -99,29 +221,32 @@ class ModelHelperBase(object): known_working_ops = [ "Accuracy", "Adam", + "Add", "AveragedLoss", "Cast", + "ConstantFill", + "DequeueBlobs", "EnsureCPUOutput", + "FlattenToVec", "LabelCrossEntropy", "LearningRate", + "MakeTwoClass", + "NCCLAllreduce", + "NHWC2NCHW", "Print", - "Sigmoid", "Scale", + "ScatterWeightedSum", + "Sigmoid", "Snapshot", "Softmax", + "SoftmaxWithLoss", + "SquaredL2Distance", + "Squeeze", "StopGradient", "Summarize", "Tanh", + "PRelu", "WeightedSum", - "SquaredL2Distance", - "FlattenToVec", - "NHWC2NCHW", - "ScatterWeightedSum", - "Squeeze", - "NCCLAllreduce", - "ConstantFill", - "Add", - "DequeueBlobs", ] if op_type not in known_working_ops: assert self.allow_not_known_ops diff --git a/caffe2/python/models/resnet.py b/caffe2/python/models/resnet.py index 30caa77748d8..d395b6c4b064 100644 --- a/caffe2/python/models/resnet.py +++ b/caffe2/python/models/resnet.py @@ -182,7 +182,7 @@ class ResNetBuilder(): self.comp_count += 1 -def create_resnet50(model, data, num_input_channels, num_labels): +def create_resnet50(model, data, num_input_channels, num_labels, label=None): # conv1 + maxpool model.Conv(data, 'conv1', num_input_channels, 64, weight_init=("MSRAFill", {}), kernel=7, stride=2, pad=3) model.SpatialBN('conv1', 'conv1_spatbn', 64, epsilon=1e-3) @@ -218,9 +218,17 @@ def create_resnet50(model, data, num_input_channels, num_labels): # Final dimension of the "image" is reduced to 7x7 model.FC('final_avg', 'pred', 2048, num_labels) - softmax = model.Softmax('pred', 'softmax') - return softmax + # If we create model for training, use softmax-with-loss + if (label is not None): + (softmax, loss) = model.SoftmaxWithLoss( + ["pred", label], + ["softmax", "loss"], + ) + return (softmax, loss) + else: + # For inference, we just return softmax + return model.Softmax("pred", "softmax") def create_resnet_32x32( model, data, num_input_channels, num_groups, num_labels diff --git a/caffe2/python/net_builder.py b/caffe2/python/net_builder.py new file mode 100644 index 000000000000..eafec5d6cbb9 --- /dev/null +++ b/caffe2/python/net_builder.py @@ -0,0 +1,251 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, context + + +@context.define_context() +class NetBuilder(object): + """ + Scope-driven mechanism for building nets, loops and conditional blocks. + Example: + from caffe2.python.net_builder import NetBuilder, ops + with NetBuilder() as nb: + c = ops.Const(5) + d = ops.Const(0) + with ops.loop(): + ops.stop_if(ops.LE([c, ops.Const(0)])) + ops.Add([c, ops.Const(-1)], [c]) + with ops.If(ops.GE([c, ops.Const(3)])): + ops.Add([d, ops.Const(10)]) + ops.Print(c, []) + ops.Print(d, []) + step = core.to_execution_step(nb) + """ + def __init__(self, name=None, _stop_blob_required=False): + self._name = name or '' + self._prefix = name + '/' if name else '' + self._frozen = False + self._current_net = None + self._children = [] + self._stop_blob = None + self._stop_blob_required = _stop_blob_required + + def stop_blob(self): + """ + Returns the BlobReference to the stop_blob of this NetBuilder. + If one is not yet available, creates one. + This function assumes that the stop_blob() will be used immediatelly + in the current net, so it doesn't initialize it if the current net is + the first of the builder. + """ + if self._stop_blob is None: + net = self.current_net() + self._stop_blob = core.BlobReference( + net.NextName('stop_blob'), net=net) + if self._current_net != self._children[0]: + self._children.insert(0, core.Net( + self._prefix + 'stop_blob_init')) + self._children[0].Const(False, blob_out=self._stop_blob) + return self._stop_blob + + def stop_if(self, blob): + ops.Copy(blob, self.stop_blob()) + self._current_net = None + + def _assert_mutable(self): + assert not self._frozen, ( + 'This NetBuilder (%s) has been built already.' % self._name) + + def add(self, child): + self._assert_mutable() + self._current_net = None + self._children.append(child) + # to-do : check it's not a dag net + if isinstance(child, core.Net): + self._current_net = child + return child + + def current_net(self): + self._assert_mutable() + if self._current_net is None: + self.add(core.Net(self._prefix + 'net')) + return self._current_net + + def freeze(self): + for child in self._children: + if hasattr(child, 'freeze'): + child.freeze() + self._current_net = None + self._frozen = True + + def get(self): + self.freeze() + return self._children + + def __exit__(self, etype, *args): + self.freeze() + if etype is not None: + return + assert (not self._stop_blob_required) or self._stop_blob is not None, ( + 'This NetBuilder (%s) requires a stop condition ' % self._name + + 'to be set with `stop` or `stop_if`') + + +class Operations(object): + """ + Operations to be used in the context of a NetBuilder. + """ + def net(self, net=None): + """ + Retrieves the current net, or add a new net to the builder. + """ + if net is not None: + NetBuilder.current().add(net) + return net + return NetBuilder.current().current_net() + + def __getattr__(self, op_type): + """ + Adds an operator call to the currently active Net. + """ + if op_type.startswith('__'): + raise AttributeError() + return getattr(self.net(), op_type) + + def task_group(self): + """ + Creates a local task group which will execute as the next step of + the current NetBuilder. + """ + from caffe2.python import task + group = NetBuilder.current() + with task.Cluster(): + with task.Node('local'): + tg = task.TaskGroup() + group.add(tg) + return tg + + def stop(self): + """ + Stop execution of the current execution step. + Example: + ops.Print(a, 0) + ops.stop() + ops.Print(b, 0) + In the example, 'b' will never be printed. + """ + return self.stop_if(ops.Const(True)) + + def stop_if(self, blob): + """ + Stop execution of the current execution step if the + condition `blob` is met. + Example: + ops.Print(a, 0) + ops.stop_if(ops.LE([x, ops.Const(0)])) + ops.Print(b, 0) + In the example, 'b' will only be printed if the value of scalar + tensor 'x' lower or equal to 0. + """ + return NetBuilder.current().stop_if(blob) + + def loop(self): + """ + Creates a NetBuilder that will execute in a loop as the next step of + the current NetBuilder. + Example: + a = ops.Const(5) + with ops.loop(): + ops.stop_if(ops.LE([a, ops.Const(0)])) + ops.Print(a, 0) + ops.Add([a, ops.Const(-1)], [a]) + In the example, 'a' will be printed 5 times, with values 5 to 1. + """ + return NetBuilder.current().add(NetBuilder(_stop_blob_required=True)) + + def stop_guard(self, has_stopped_blob=None): + """ + Creates a NetBuilder that will execute once as the next step of the + current NetBuilder. After execution, a bool tensor will indicate + whether the inner execution was halted with `stop` or `stop_if`. + Example: + a = ops.Const(True) + with ops.stop_guard() as sg1: + ops.stop_if(a) + ops.Print(ops.Const('did not stop')) + b = ops.Const(False) + with ops.stop_guard() as sg2: + ops.stop_if(b) + ops.Print(ops.Const('did not stop')) + ops.Print(sg1.has_stopped(), []) + ops.Print(sg2.has_stopped(), []) + In the example, 'did not stop' will be printed once, + followed by True and False. + """ + return NetBuilder.current().add( + _StopGuard(has_stopped_blob=has_stopped_blob)) + + def If(self, cond): + """ + Creates a NetBuilder that will execute once as the next step of the + current NetBuilder if the blob `cond` is True. + Example: + with ops.If(ops.Const(True)): + ops.Print(ops.Const('Will print')) + with ops.If(ops.Const(False)): + ops.Print(ops.Const('Wont print')) + The example will print 'Will print' once. + """ + return NetBuilder.current().add(_RunIf(cond)) + + +ops = Operations() + + +class _RunOnce(NetBuilder): + def __init__(self, name=None): + NetBuilder.__init__(self, name) + + def __exit__(self, *args): + if self._stop_blob is not None: + ops.stop() + NetBuilder.__exit__(self, *args) + + +class _StopGuard(_RunOnce): + def __init__(self, name=None, has_stopped_blob=None): + _RunOnce.__init__(self, name) + self._stopped = has_stopped_blob + self._ran = False + + def __enter__(self): + r = _RunOnce.__enter__(self) + self._stopped = ops.Const(True, blob_out=self._stopped) + return r + + def __exit__(self, *args): + self._ran = True + ops.Const(False, blob_out=self._stopped) + _RunOnce.__exit__(self, args) + + def has_stopped(self): + """ + Return a blob that will be set to scalar bool `True` after + this net builder ran, iff it was halted early. + """ + assert self._ran, 'Context not used yet.' + return self._stopped + + +class _RunIf(_RunOnce): + def __init__(self, cond_blob, name=None): + _RunOnce.__init__(self, name) + self._cond_blob = cond_blob + + def __enter__(self): + r = _RunOnce.__enter__(self) + ops.stop_if(self._cond_blob) + return r diff --git a/caffe2/python/net_builder_test.py b/caffe2/python/net_builder_test.py new file mode 100644 index 000000000000..801e17ff9ba1 --- /dev/null +++ b/caffe2/python/net_builder_test.py @@ -0,0 +1,82 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import workspace +from caffe2.python.core import Plan, to_execution_step +from caffe2.python.net_builder import ops, NetBuilder +import unittest + + +def test_loop(): + x = ops.Const(5) + y = ops.Const(0) + with ops.loop(): + ops.stop_if(ops.EQ([x, ops.Const(0)])) + ops.Add([x, ops.Const(-1)], [x]) + ops.Add([y, ops.Const(1)], [y]) + return y + + +def test_inner_stop(x): + ops.stop_if(ops.LT([x, ops.Const(5)])) + + +def test_outer(): + x = ops.Const(10) + # test stop_if(False) + with ops.stop_guard() as g1: + test_inner_stop(x) + + # test stop_if(True) + y = ops.Const(3) + with ops.stop_guard() as g2: + test_inner_stop(y) + + # test no stop + with ops.stop_guard() as g4: + ops.Const(0) + + # test empty clause + with ops.stop_guard() as g3: + pass + + return ( + g1.has_stopped(), g2.has_stopped(), g3.has_stopped(), g4.has_stopped()) + + +def test_if(x): + y = ops.Const(1) + with ops.If(ops.GT([x, ops.Const(50)])): + ops.Const(2, blob_out=y) + with ops.If(ops.LT([x, ops.Const(50)])): + ops.Const(3, blob_out=y) + ops.stop() + ops.Const(4, blob_out=y) + return y + + +class TestNetBuilder(unittest.TestCase): + def test_ops(self): + with NetBuilder() as nb: + y = test_loop() + z, w, a, b = test_outer() + p = test_if(ops.Const(75)) + q = test_if(ops.Const(25)) + plan = Plan('name') + plan.AddStep(to_execution_step(nb)) + ws = workspace.C.Workspace() + ws.run(plan) + expected = [ + (y, 5), + (z, False), + (w, True), + (a, False), + (b, False), + (p, 3), + (q, 2), + ] + for b, expected in expected: + actual = ws.blobs[str(b)].fetch() + self.assertEquals(actual, expected) diff --git a/caffe2/python/net_drawer.py b/caffe2/python/net_drawer.py index 9645d851d422..fc91eeed023f 100644 --- a/caffe2/python/net_drawer.py +++ b/caffe2/python/net_drawer.py @@ -93,7 +93,8 @@ def GetPydotGraph( for input_name in op.input: if input_name not in pydot_nodes: input_node = pydot.Node( - input_name + str(pydot_node_counts[input_name]), + _escape_label( + input_name + str(pydot_node_counts[input_name])), label=_escape_label(input_name), **BLOB_STYLE ) @@ -107,7 +108,8 @@ def GetPydotGraph( # we are overwriting an existing blob. need to updat the count. pydot_node_counts[output_name] += 1 output_node = pydot.Node( - output_name + str(pydot_node_counts[output_name]), + _escape_label( + output_name + str(pydot_node_counts[output_name])), label=_escape_label(output_name), **BLOB_STYLE ) @@ -199,6 +201,8 @@ def _draw_steps(steps, g, skip_step_edges=False): # noqa label.append('Stopper: {}'.format(step.should_stop_blob)) if step.concurrent_substeps: label.append('Concurrent') + if step.only_once: + label.append('Once') return '\n'.join(label) def substep_edge(start, end): diff --git a/caffe2/python/op/python.py b/caffe2/python/op/python.py deleted file mode 100644 index 4fcc36450bc2..000000000000 --- a/caffe2/python/op/python.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals -from caffe2.python import core - -from caffe2.python.op.python_ops_python import \ - register, register_gradient, TensorCPU - - -def _TensorCPU_shape(self): - return tuple(self._shape) - - -def _TensorCPU_reshape(self, shape): - return self._reshape(list(shape)) - -TensorCPU.shape = property(_TensorCPU_shape) -TensorCPU.reshape = _TensorCPU_reshape - - -def CreatePythonOperator(f, inputs, outputs, grad_f=None, *args, **kwargs): - token = register(f) - if grad_f: - register_gradient(token, grad_f) - kwargs["token"] = token - return core.CreateOperator("Python", inputs, outputs, *args, **kwargs) diff --git a/caffe2/python/op/python_op.cpp b/caffe2/python/op/python_op.cpp deleted file mode 100644 index 1c950afac868..000000000000 --- a/caffe2/python/op/python_op.cpp +++ /dev/null @@ -1,206 +0,0 @@ -#include - -#include "caffe2/core/context.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/tensor.h" - -#include -#include - -// Produce deprecation warnings (needs to come before arrayobject.h inclusion). -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - -#include - -// Temporary solution for numpy < 1.7 versions: old macro, no promises. -// You're strongly advised to upgrade to >= 1.7. -#ifndef NPY_ARRAY_C_CONTIGUOUS -#define NPY_ARRAY_C_CONTIGUOUS NPY_C_CONTIGUOUS -#define PyArray_SetBaseObject(arr, x) (PyArray_BASE(arr) = (x)) -#endif - -namespace caffe2 { - -namespace py = pybind11; - -namespace { - -using FuncRegistery = std::unordered_map; -static FuncRegistery& gRegistery() { - // Always leak the objects registered here. - static FuncRegistery* r = new FuncRegistery(); - return *r; -} - -py::object& getOpFunc(const std::string& token) { - return gRegistery()[token]; -} - -py::object& getGradientFunc(const std::string& token) { - return gRegistery()[token + "_gradient"]; -} - -} - -class PythonOpBase : public Operator { - public: - using Operator::Operator; - - bool RunOnDevice() final { - std::vector inputs; - inputs.reserve(InputSize()); - for (auto i = 0; i < InputSize(); ++i) { - inputs.push_back(const_cast(&Input(i))); - } - std::vector outputs; - outputs.reserve(OutputSize()); - for (auto i = 0; i < OutputSize(); ++i) { - outputs.push_back(Output(i)); - } - auto& pyFunc = getFunc(); - { - // Acquire GIL for call to Python runtime. - py::gil_scoped_acquire g; - try { - pyFunc(inputs, outputs); - } catch (const py::error_already_set& e) { - LOG(ERROR) << "Exception encountered running PythonOp function: " - << e.what() << "\nTraceback: "; - PyObject *type = nullptr, *value = nullptr, *trace = nullptr; - PyErr_Fetch(&type, &value, &trace); - PyTracebackObject* traceback = - reinterpret_cast(trace); - vector trace_vec; - while (traceback) { - trace_vec.push_back(traceback); - traceback = traceback->tb_next; - } - for (int i = trace_vec.size() - 1; i >= 0; --i) { - int line = trace_vec[i]->tb_lineno; - const char* filename = - PyString_AsString(trace_vec[i]->tb_frame->f_code->co_filename); - const char* funcname = - PyString_AsString(trace_vec[i]->tb_frame->f_code->co_name); - LOG(ERROR) << " # " << trace_vec.size() - i - 1 << " " << filename - << " (" << line << "): " << funcname; - } - Py_XDECREF(type); - Py_XDECREF(value); - Py_XDECREF(trace); - return false; - } - } - return true; - } - - private: - virtual py::object& getFunc() = 0; -}; - -class PythonOp final : public PythonOpBase { - public: - using PythonOpBase::PythonOpBase; - - private: - py::object& getFunc() override { - const std::string& token = - OperatorBase::GetSingleArgument("token", ""); - return getOpFunc(token); - } -}; - -class PythonGradientOp final : public PythonOpBase { - public: - using PythonOpBase::PythonOpBase; - - private: - py::object& getFunc() override { - const std::string& token = - OperatorBase::GetSingleArgument("token", ""); - return getGradientFunc(token); - } -}; - -PYBIND11_PLUGIN(python_ops_python) { - py::module m("python_ops_python", "pybind11 interface to operators"); - - py::class_(m, "TensorCPU") - .def_property_readonly( - "data", - [](TensorCPU* t) -> py::object { - CAFFE_ENFORCE(t->size() > 0); - std::vector npy_dims; - for (const auto dim : t->dims()) { - npy_dims.push_back(dim); - } - PyObject* array = PyArray_SimpleNewFromData( - t->ndim(), - npy_dims.data(), - NPY_FLOAT32, - t->mutable_data()); - return py::object(array, /* borrowed= */ false); - }) - .def_property_readonly( - "_shape", [](const TensorCPU& t) { return t.dims(); }) - .def("_reshape", [](TensorCPU* t, std::vector dims) { - t->Resize(dims); - }); - - m.def("register", [](py::object func) { - CAFFE_ENFORCE(func != py::none()); - const std::string name = func.attr("__name__").cast(); - // Unique name since registry is never cleared. - const std::string token = name + to_string(gRegistery().size()); - CAFFE_ENFORCE(gRegistery().find(name) == gRegistery().end()); - gRegistery()[token] = func; - return token; - }); - - m.def("register_gradient", [](const std::string& token, py::object func) { - CAFFE_ENFORCE(func != py::none()); - CAFFE_ENFORCE(gRegistery().find(token) != gRegistery().end()); - gRegistery()[token + "_gradient"] = func; - }); - ([]() { - // This is a workaround so we can deal with numpy's import_array behavior. - // Despite the fact that you may think import_array() is a function call, - // it is defined as a macro (as of 1.10). - import_array(); - })(); - return m.ptr(); -} - -namespace { - -struct GetPythonGradient : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - std::vector GetGradientDefs() override { - std::vector gradientInputs; - for (int i = 0; i < def_.input_size(); ++i) { - gradientInputs.push_back(I(i)); - } - for (int i = 0; i < def_.output_size(); ++i) { - gradientInputs.push_back(O(i)); - } - for (int i = 0; i < def_.output_size(); ++i) { - gradientInputs.push_back(GO(i)); - } - std::vector gradientOutputs; - for (int i = 0; i < def_.input_size(); ++i) { - gradientOutputs.push_back(GI(i)); - } - - return SingleGradientDef( - "PythonGradient", "", gradientInputs, gradientOutputs); - } -}; - -REGISTER_CPU_OPERATOR(Python, PythonOp); -REGISTER_CPU_OPERATOR(PythonGradient, PythonGradientOp); -// Always allow running in-place -OPERATOR_SCHEMA(Python).AllowInplace([](int, int) { return true; }); -OPERATOR_SCHEMA(PythonGradient).AllowInplace([](int, int) { return true; }); - -REGISTER_GRADIENT(Python, GetPythonGradient); -} -} diff --git a/caffe2/python/operator_test/activation_ops_test.py b/caffe2/python/operator_test/activation_ops_test.py new file mode 100644 index 000000000000..e309744f538a --- /dev/null +++ b/caffe2/python/operator_test/activation_ops_test.py @@ -0,0 +1,81 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np + +from hypothesis import given +import hypothesis.strategies as st + +from caffe2.python import core +import caffe2.python.hypothesis_test_util as hu + + +class TestActivations(hu.HypothesisTestCase): + @given(X=hu.tensor(), + alpha=st.floats(min_value=0.1, max_value=2.0), + inplace=st.booleans(), + **hu.gcs_cpu_only) + def test_elu(self, X, alpha, inplace, gc, dc): + # go away from the origin point to avoid kink problems + X += 0.04 * np.sign(X) + X[X == 0.0] += 0.04 + + def elu_ref(X): + Y = X.copy() + neg_indices = X <= 0 + Y[neg_indices] = alpha * (np.exp(Y[neg_indices]) - 1) + return (Y,) + + op = core.CreateOperator( + "Elu", + ["X"], ["Y" if not inplace else "X"], + alpha=alpha) + self.assertReferenceChecks(gc, op, [X], elu_ref) + # Check over multiple devices + self.assertDeviceChecks(dc, op, [X], [0]) + # Gradient check wrt X + self.assertGradientChecks(gc, op, [X], 0, [0]) + + @given(X=hu.tensor(min_dim=4, max_dim=4), + alpha=st.floats(min_value=0.1, max_value=2.0), + inplace=st.booleans(), + shared=st.booleans(), + order=st.sampled_from(["NCHW", "NHWC"]), + **hu.gcs_cpu_only) + def test_prelu(self, X, alpha, inplace, shared, order, gc, dc): + np.random.seed(20) + W = np.random.randn( + X.shape[1] if order == "NCHW" else X.shape[3]).astype(np.float32) + + if shared: + W = np.random.randn(1).astype(np.float32) + + # go away from the origin point to avoid kink problems + X += 0.04 * np.sign(X) + X[X == 0.0] += 0.04 + + def prelu_ref(X, W): + Y = X.copy() + W = W.reshape(1, -1, 1, 1) if order == "NCHW" \ + else W.reshape(1, 1, 1, -1) + assert len(X.shape) == 4 + neg_indices = X <= 0 + assert len(neg_indices.shape) == 4 + assert X.shape == neg_indices.shape + Y[neg_indices] = (Y * W)[neg_indices] + return (Y,) + + op = core.CreateOperator( + "PRelu", ["X", "W"], ["Y" if not inplace else "X"], + alpha=alpha, order=order) + self.assertReferenceChecks(gc, op, [X, W], prelu_ref) + # Check over multiple devices + self.assertDeviceChecks(dc, op, [X, W], [0]) + + if not inplace: + # Gradient check wrt X + self.assertGradientChecks(gc, op, [X, W], 0, [0]) + # Gradient check wrt W + self.assertGradientChecks(gc, op, [X, W], 1, [0]) diff --git a/caffe2/python/operator_test/conv_transpose_test.py b/caffe2/python/operator_test/conv_transpose_test.py index b7b008f4cf7c..82a63ac57ed1 100644 --- a/caffe2/python/operator_test/conv_transpose_test.py +++ b/caffe2/python/operator_test/conv_transpose_test.py @@ -19,7 +19,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase): input_channels=st.integers(1, 8), output_channels=st.integers(1, 8), batch_size=st.integers(1, 3), - engine=st.sampled_from(["", "CUDNN"]), + engine=st.sampled_from(["", "CUDNN", "BLOCK"]), shared_buffer=st.booleans(), **hu.gcs) def test_convolution_transpose_layout(self, stride, pad, kernel, adj, @@ -83,7 +83,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase): input_channels=st.integers(1, 8), output_channels=st.integers(1, 8), batch_size=st.integers(1, 3), - engine=st.sampled_from([""]), **hu.gcs) + engine=st.sampled_from(["", "BLOCK"]), **hu.gcs) def test_convolution_transpose_separate_stride_pad_adj_layout( self, stride_h, stride_w, pad_t, pad_l, pad_b, pad_r, kernel, adj_h, adj_w, size, input_channels, output_channels, batch_size, @@ -146,7 +146,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase): output_channels=st.integers(1, 8), batch_size=st.integers(1, 3), order=st.sampled_from(["NCHW", "NHWC"]), - engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) + engine=st.sampled_from(["", "CUDNN", "BLOCK"]), **hu.gcs) @settings(max_examples=2, timeout=100) def test_convolution_transpose_gradients(self, stride, pad, kernel, adj, size, input_channels, @@ -193,7 +193,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase): output_channels=st.integers(1, 8), batch_size=st.integers(1, 3), order=st.sampled_from(["NCHW", "NHWC"]), - engine=st.sampled_from([""]), **hu.gcs) + engine=st.sampled_from(["", "BLOCK"]), **hu.gcs) @settings(max_examples=2, timeout=100) def test_convolution_transpose_separate_stride_pad_adj_gradient( self, stride_h, stride_w, pad_t, pad_l, pad_b, pad_r, kernel, diff --git a/caffe2/python/operator_test/counter_ops_test.py b/caffe2/python/operator_test/counter_ops_test.py index ad942a6cb21c..e446f0665c95 100644 --- a/caffe2/python/operator_test/counter_ops_test.py +++ b/caffe2/python/operator_test/counter_ops_test.py @@ -5,9 +5,11 @@ from __future__ import unicode_literals from caffe2.python import core, workspace from caffe2.python.test_util import TestCase +import tempfile class TestCounterOps(TestCase): + def test_counter_ops(self): workspace.RunOperatorOnce(core.CreateOperator( 'CreateCounter', [], ['c'], init_count=1)) @@ -50,3 +52,22 @@ class TestCounterOps(TestCase): assert workspace.RunOperatorOnce(core.CreateOperator( 'And', ['t2', 't5'], ['t7'])) assert workspace.FetchBlob('t7') # True && True + + workspace.RunOperatorOnce(core.CreateOperator( + 'CreateCounter', [], ['serialized_c'], init_count=22)) + with tempfile.NamedTemporaryFile() as tmp: + workspace.RunOperatorOnce(core.CreateOperator( + 'Save', ['serialized_c'], [], absolute_path=1, + db_type='minidb', db=tmp.name)) + for i in range(10): + workspace.RunOperatorOnce(core.CreateOperator( + 'CountDown', ['serialized_c'], ['t8'])) + workspace.RunOperatorOnce(core.CreateOperator( + 'RetrieveCount', ['serialized_c'], ['t8'])) + assert workspace.FetchBlob('t8') == 12 + workspace.RunOperatorOnce(core.CreateOperator( + 'Load', [], ['serialized_c'], absolute_path=1, + db_type='minidb', db=tmp.name)) + workspace.RunOperatorOnce(core.CreateOperator( + 'RetrieveCount', ['serialized_c'], ['t8'])) + assert workspace.FetchBlob('t8') == 22 diff --git a/caffe2/python/operator_test/dataset_ops_test.py b/caffe2/python/operator_test/dataset_ops_test.py index 83d612876932..371d3d74d9bd 100644 --- a/caffe2/python/operator_test/dataset_ops_test.py +++ b/caffe2/python/operator_test/dataset_ops_test.py @@ -332,16 +332,12 @@ class TestDatasetOps(TestCase): collect_net = core.Net('collect_net') num_to_collect = 1000 max_example_to_cover = 100000 - for i, b in enumerate(blobs): - if i == 0: - bvec_map[b], position = collect_net.CollectTensor( - [bvec_map[b], b], [bvec_map[b], 'position'], - num_to_collect=num_to_collect) - else: - # sample in the same way as the first blob - bvec_map[b], position = collect_net.CollectTensor( - [bvec_map[b], b, position], [bvec_map[b], position], - num_to_collect=num_to_collect) + bvec = [bvec_map[b] for b in blobs] + collect_net.CollectTensor( + bvec + blobs, + bvec, + num_to_collect=num_to_collect, + ) print('Collect Net Proto: {}'.format(collect_net.Proto())) diff --git a/caffe2/python/operator_test/matmul_op_test.py b/caffe2/python/operator_test/matmul_op_test.py index b656d219f4dd..b29eb3af3002 100644 --- a/caffe2/python/operator_test/matmul_op_test.py +++ b/caffe2/python/operator_test/matmul_op_test.py @@ -20,11 +20,11 @@ class TestMatMul(hu.HypothesisTestCase): trans_b=st.booleans(), **hu.gcs) def test_matmul(self, M, K, N, trans_a, trans_b, gc, dc): - X = np.random.randn(M, K).astype(np.float32) + X = np.random.rand(M, K).astype(np.float32) - 0.5 if trans_a: X = X.transpose() - Y = np.random.randn(K, N).astype(np.float32) + Y = np.random.rand(K, N).astype(np.float32) - 0.5 if trans_b: Y = Y.transpose() @@ -56,11 +56,11 @@ class TestBatchMatMul(hu.HypothesisTestCase): trans_b=st.booleans(), **hu.gcs) def test_matmul(self, C, M, K, N, trans_a, trans_b, gc, dc): - X = np.random.randn(C, M, K).astype(np.float32) + X = np.random.rand(C, M, K).astype(np.float32) - 0.5 if trans_a: X = X.swapaxes(1, 2) - Y = np.random.randn(C, K, N).astype(np.float32) + Y = np.random.rand(C, K, N).astype(np.float32) - 0.5 if trans_b: Y = Y.swapaxes(1, 2) diff --git a/caffe2/python/operator_test/mkl_ops_test.py b/caffe2/python/operator_test/mkl_ops_test.py index 4d3455999734..3d164d22610e 100644 --- a/caffe2/python/operator_test/mkl_ops_test.py +++ b/caffe2/python/operator_test/mkl_ops_test.py @@ -19,6 +19,9 @@ class PackedFCTest(hu.HypothesisTestCase): K=st.integers(128, 1024), N=st.integers(128, 1024), **hu.gcs_cpu_only) + @unittest.skipIf(not core.C.builtin_cpu_supports_avx2(), + "Intel MKL sgemm_pack has a known numerical issue with " + "non-avx2 machines that will be fixed in a later build.") def test_packed_fc(self, seed, M, K, N, gc, dc): np.random.seed(seed) X = np.random.rand(M, K).astype(np.float32) - 0.5 @@ -41,6 +44,9 @@ class PackedFCTest(hu.HypothesisTestCase): ) self.assertReferenceChecks(gc, op, [X, W, b], ref) + @unittest.skipIf(not core.C.builtin_cpu_supports_avx2(), + "Intel MKL sgemm_pack has a known numerical issue with " + "non-avx2 machines that will be fixed in a later build.") @given(axis=st.integers(min_value=1, max_value=4), num_output=st.integers(min_value=4, max_value=8), **hu.gcs_cpu_only) @@ -59,6 +65,8 @@ class PackedFCTest(hu.HypothesisTestCase): axis=axis) def ref(X, W, b): - return (np.dot(X.reshape(X.size / K, K), W.T) + b,) + output_axes = list(X.shape[:axis]) + [N] + return ( + np.dot(X.reshape(X.size / K, K), W.T).reshape(output_axes) + b,) self.assertReferenceChecks(gc, op, [X, W, b], ref) diff --git a/caffe2/python/operator_test/pack_ops_test.py b/caffe2/python/operator_test/pack_ops_test.py index 509da87a40af..67243b0342a0 100644 --- a/caffe2/python/operator_test/pack_ops_test.py +++ b/caffe2/python/operator_test/pack_ops_test.py @@ -41,3 +41,28 @@ class TestTensorPackOps(TestCase): workspace.RunOperatorOnce(core.CreateOperator( 'UnpackSegments', ['l', 't'], ['newd'])) assert((workspace.FetchBlob('newd') == workspace.FetchBlob('d')).all()) + + def test_pad_minf(self): + workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int32)) + workspace.FeedBlob( + 'd', + np.array([ + [1.0, 1.0], + [2.0, 2.0], + [2.0, 2.0], + [3.0, 3.0], + [3.0, 3.0], + [3.0, 3.0]], + dtype=np.float32)) + workspace.RunOperatorOnce(core.CreateOperator( + 'PackSegments', ['l', 'd'], ['t'], pad_minf=True)) + workspace.RunOperatorOnce(core.CreateOperator( + 'Exp', ['t'], ['r'] + )) + result = workspace.FetchBlob('t') + assert(result[0, -1, 0] < -1000.0) + + # The whole point of padding with -inf is that when we exponentiate it + # then it should be zero. + exponentiated = workspace.FetchBlob('r') + assert(exponentiated[0, -1, 0] == 0.0) diff --git a/caffe2/python/operator_test/pooling_test.py b/caffe2/python/operator_test/pooling_test.py index 8b7ba5770d23..a3c3890407f0 100644 --- a/caffe2/python/operator_test/pooling_test.py +++ b/caffe2/python/operator_test/pooling_test.py @@ -24,7 +24,7 @@ class TestPooling(hu.HypothesisTestCase): input_channels=st.integers(1, 3), batch_size=st.integers(1, 3), order=st.sampled_from(["NCHW", "NHWC"]), - method=st.sampled_from(["MaxPool", "AveragePool"]), + method=st.sampled_from(["MaxPool", "AveragePool", "LpPool"]), **hu.gcs) def test_pooling_separate_stride_pad(self, stride_h, stride_w, pad_t, pad_l, pad_b, @@ -49,10 +49,11 @@ class TestPooling(hu.HypothesisTestCase): ) X = np.random.rand( batch_size, size, size, input_channels).astype(np.float32) + if order == "NCHW": X = X.transpose((0, 3, 1, 2)) self.assertDeviceChecks(dc, op, [X], [0]) - if method != 'MaxPool': + if method not in ('MaxPool'): self.assertGradientChecks(gc, op, [X], 0, [0]) @given(stride=st.integers(1, 3), @@ -62,7 +63,7 @@ class TestPooling(hu.HypothesisTestCase): input_channels=st.integers(1, 3), batch_size=st.integers(1, 3), order=st.sampled_from(["NCHW", "NHWC"]), - method=st.sampled_from(["MaxPool", "AveragePool"]), + method=st.sampled_from(["MaxPool", "AveragePool", "LpPool"]), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) def test_pooling(self, stride, pad, kernel, size, @@ -85,14 +86,14 @@ class TestPooling(hu.HypothesisTestCase): X = X.transpose((0, 3, 1, 2)) self.assertDeviceChecks(dc, op, [X], [0]) - if method != 'MaxPool': + if method not in ('MaxPool'): self.assertGradientChecks(gc, op, [X], 0, [0]) @given(size=st.integers(7, 9), input_channels=st.integers(1, 3), batch_size=st.integers(1, 3), order=st.sampled_from(["NCHW", "NHWC"]), - method=st.sampled_from(["MaxPool", "AveragePool"]), + method=st.sampled_from(["MaxPool", "AveragePool", "LpPool"]), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) def test_global_pooling(self, size, input_channels, batch_size, @@ -111,5 +112,5 @@ class TestPooling(hu.HypothesisTestCase): X = X.transpose((0, 3, 1, 2)) self.assertDeviceChecks(dc, op, [X], [0]) - if method != 'MaxPool': + if method not in ('MaxPool'): self.assertGradientChecks(gc, op, [X], 0, [0]) diff --git a/caffe2/python/operator_test/python_op_test.py b/caffe2/python/operator_test/python_op_test.py new file mode 100644 index 000000000000..25063943b264 --- /dev/null +++ b/caffe2/python/operator_test/python_op_test.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from caffe2.python import core, workspace +from caffe2.python.core import CreatePythonOperator +import caffe2.python.hypothesis_test_util as hu +from hypothesis import given +import hypothesis.strategies as st +import numpy as np +import unittest + +try: + import numba + HAS_NUMBA = True +except ImportError: + HAS_NUMBA = False + + +class PythonOpTest(hu.HypothesisTestCase): + @unittest.skipIf(not HAS_NUMBA, "") + @given(x=hu.tensor(), + n=st.integers(min_value=1, max_value=20), + w=st.integers(min_value=1, max_value=20)) + def test_multithreaded_evaluation_numba_nogil(self, x, n, w): + @numba.jit(nopython=True, nogil=True) + def g(input_, output): + output[...] = input_ + + def f(inputs, outputs): + outputs[0].reshape(inputs[0].shape) + g(inputs[0].data, outputs[0].data) + + ops = [CreatePythonOperator(f, ["x"], [str(i)]) for i in range(n)] + net = core.Net("net") + net.Proto().op.extend(ops) + net.Proto().type = "dag" + net.Proto().num_workers = w + iters = 100 + plan = core.Plan("plan") + plan.AddStep(core.ExecutionStep("test-step", net, iters)) + workspace.FeedBlob("x", x) + workspace.RunPlan(plan.Proto().SerializeToString()) + for i in range(n): + y = workspace.FetchBlob(str(i)) + np.testing.assert_almost_equal(x, y) diff --git a/caffe2/python/operator_test/reshape_ops_test.py b/caffe2/python/operator_test/reshape_ops_test.py index f9e2d93bc75d..c9b7dee33854 100644 --- a/caffe2/python/operator_test/reshape_ops_test.py +++ b/caffe2/python/operator_test/reshape_ops_test.py @@ -38,6 +38,26 @@ class TestLengthsToShapeOps(TestCase): test_reshape(old_shape=(4, 2, 1), new_shape=(-1, 8), in_place=True, arg_shape=False) + def test_zero_dim(self): + test_reshape(old_shape=(4, 2, 1), new_shape=(0, 0, 0), + expected_shape=(4, 2, 1)) + test_reshape(old_shape=(4, 2, 1), new_shape=(0, 0, 0), + expected_shape=(4, 2, 1), arg_shape=False) + test_reshape(old_shape=(4, 2, 1), new_shape=(0, 2, 1), + expected_shape=(4, 2, 1)) + test_reshape(old_shape=(4, 2, 1), new_shape=(0, 2, 1), + expected_shape=(4, 2, 1), arg_shape=False) + + def test_zero_dim_and_missing_dim(self): + test_reshape(old_shape=(4, 2, 1), new_shape=(0, -1, 0), + expected_shape=(4, 2, 1)) + test_reshape(old_shape=(4, 2, 1), new_shape=(0, -1, 0), + expected_shape=(4, 2, 1), arg_shape=False) + test_reshape(old_shape=(4, 3, 2), new_shape=(-1, 0), + expected_shape=(8, 3)) + test_reshape(old_shape=(4, 3, 2), new_shape=(-1, 0), + expected_shape=(8, 3), arg_shape=False) + def test_backprop(self): old_shape = (4, 2, 1) new_shape = (1, 8) @@ -84,7 +104,10 @@ class TestLengthsToShapeOps(TestCase): workspace.RunNet(net) -def test_reshape(old_shape, new_shape, arg_shape=True, in_place=False): +def test_reshape(old_shape, new_shape, expected_shape=None, arg_shape=True, + in_place=False): + if expected_shape is None: + expected_shape = new_shape X = np.random.rand(*old_shape).astype(np.float32) blob_in = 'X' @@ -105,4 +128,4 @@ def test_reshape(old_shape, new_shape, arg_shape=True, in_place=False): workspace.RunOperatorOnce(op) Y = workspace.FetchBlob(blob_out) - np.testing.assert_allclose(Y, X.reshape(new_shape)) + np.testing.assert_allclose(Y, X.reshape(expected_shape)) diff --git a/caffe2/python/operator_test/softmax_ops_test.py b/caffe2/python/operator_test/softmax_ops_test.py new file mode 100644 index 000000000000..4b8fc5c4059a --- /dev/null +++ b/caffe2/python/operator_test/softmax_ops_test.py @@ -0,0 +1,266 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from caffe2.python import core, workspace +from hypothesis import given +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np + +import unittest + + +class TestSoftmaxOps(hu.HypothesisTestCase): + + @given(n=st.integers(2, 10), D=st.integers(4, 16), **hu.gcs) + def test_softmax(self, n, D, gc, dc): + # n = number of examples, D = |labels| + # Initialize X and add 1e-2 for numerical stability + X = np.random.rand(n, D).astype(np.float32) + X = X + 1e-2 + + # Reference implementation of cross entropy with soft labels + def label_softmax(X): + probs = np.zeros((n, D)) + rowmax = np.zeros(n) + for i in range(n): + rowmax[i] = max(X[i, ]) + # We need to subtract the max to avoid numerical issues + probs[i] = X[i] - rowmax[i] + exps = np.exp(probs[i, ]) + norm = sum(exps) + probs[i, ] = exps / norm + + return [probs] + + op = core.CreateOperator( + "Softmax", + ["X"], + ["probs"] + ) + + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=[X], + reference=label_softmax, + ) + + self.assertGradientChecks( + gc, op, [X], 0, [0], stepsize=1e-4, threshold=1e-2) + + @given(axis=st.integers(min_value=1, max_value=4), **hu.gcs) + def test_softmax_axis(self, axis, gc, dc): + np.random.seed(1) + X = np.random.randn(1, 2, 3, 2, 1).astype(np.float32) + X = X + 1e-2 + + def prod(xs): + p = 1 + for x in xs: + p *= x + return p + + N = prod(list(X.shape)[:axis]) + D = prod(list(X.shape)[axis:]) + + # Reference implementation of cross entropy with soft labels + def label_softmax(X): + X_ = X.reshape(N, D) + probs = np.zeros((N, D)) + rowmax = np.zeros(N) + for i in range(N): + rowmax[i] = max(X_[i, ]) + # We need to subtract the max to avoid numerical issues + probs[i] = X_[i] - rowmax[i] + exps = np.exp(probs[i, ]) + norm = sum(exps) + probs[i, ] = exps / norm + + return [probs.reshape(*X.shape)] + + op = core.CreateOperator( + "Softmax", + ["X"], + ["probs"], + axis=axis, + ) + + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=[X], + reference=label_softmax, + ) + + self.assertGradientChecks( + gc, op, [X], 0, [0], stepsize=1e-4, threshold=1e-2) + + @given(n=st.integers(2, 10), D=st.integers(4, 16), **hu.gcs) + def test_softmax_with_loss(self, n, D, gc, dc): + # n = number of examples, D = |labels| + # Initialize X and add 1e-2 for numerical stability + X = np.random.rand(n, D).astype(np.float32) + X = X + 1e-2 + + # Initialize label + label = (np.random.rand(n) * D).astype(np.int32) + + # Reference implementation of cross entropy with soft labels + def label_softmax_crossent(X, label): + probs = np.zeros((n, D)) + rowmax = np.zeros(n) + for i in range(n): + rowmax[i] = max(X[i, ]) + # We need to subtract the max to avoid numerical issues + probs[i] = X[i] - rowmax[i] + exps = np.exp(probs[i, ]) + norm = sum(exps) + probs[i, ] = exps / norm + + label_xent = [-np.log(max(probs[i][label[i]], 1e-20)) + for i in range(n)] + avgloss = np.sum(label_xent) / float(n) + return (probs, avgloss) + + op = core.CreateOperator( + "SoftmaxWithLoss", + ["X", "label"], + ["probs", "avgloss"] + ) + + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=[X, label], + reference=label_softmax_crossent, + ) + + self.assertGradientChecks( + gc, op, [X, label], 0, [1], stepsize=1e-4, threshold=1e-2) + + @unittest.skipIf(not workspace.has_gpu_support, "No gpu support") + @given(n=st.integers(2, 5), D=st.integers(2, 4), + weighted=st.booleans(), **hu.gcs_gpu_only) + def test_spatial_softmax_with_loss(self, n, D, weighted, gc, dc): + # n = number of examples, D = |labels| + # Initialize X and add 1e-2 for numerical stability + W = 18 + H = 12 + X = np.random.rand(n, D, H, W).astype(np.float32) + X = X + 1e-2 + + weighted = True + weights = None + if weighted: + weights = np.random.rand(n, H, W).astype(np.float32) + + # Initialize label. Some of the labels are (-1), i.e "DONT CARE" + label = (np.random.rand(n, H, W) * (D + 1)).astype(np.int32) - 1 + + def label_softmax_crossent_spatial(X, label, weights=None): + probs = np.zeros((n, D, H, W)) + rowmax = np.zeros((n, H, W)) + label_xent = np.zeros((n, H, W)) + for i in range(n): + for x in range(W): + for y in range(H): + rowmax[i, y, x] = max(X[i, :, y, x]) + # We need to subtract the max to avoid numerical issues + probs[i, :, y, x] = X[i, :, y, x] - rowmax[i, y, x] + exps = np.exp(probs[i, :, y, x]) + probs[i, :, y, x] = exps / sum(exps) + + label_xent[:, y, x] = \ + [-np.log(max(probs[j, label[i, y, x], y, x], 1e-20)) + for j in range(n)] + + total_xent = 0.0 + total_weight = 0.0 + for y in range(H): + for x in range(W): + for i in range(n): + l = label[i, y, x] + if (l != (-1)): + w = 1.0 if weights is None else weights[i, y, x] + total_xent += \ + -np.log(max(probs[i, l, y, x], 1e-20)) * w + total_weight += w + print("Total weight {}".format(total_weight)) + + return (probs, total_xent / total_weight) + + op = core.CreateOperator( + "SoftmaxWithLoss", + ["X", "label"] + ([] if weights is None else ["weights"]), + ["probs", "avgloss"], + spatial=1 + ) + + inputs = [X, label] + ([] if weights is None else [weights]) + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=inputs, + reference=label_softmax_crossent_spatial, + ) + + self.assertGradientChecks( + gc, op, inputs, 0, [1], stepsize=1e-4, threshold=1e-2) + + @unittest.skipIf(not workspace.has_gpu_support, "No gpu support") + def test_compare_cpugpu(self): + ''' + Additional test that checks CPU and GPU returns same values + with larger examples. This is mainly to test the more complex + GPU implementation is correct. + ''' + from caffe2.proto import caffe2_pb2 + + for j in range(3): + gpuop = core.CreateOperator( + "SoftmaxWithLoss", + ["X_gpu", "label_gpu"], + ["probs_gpu", "avgloss_gpu"], + spatial=1, + device_option=core.DeviceOption(caffe2_pb2.CUDA, 0) + ) + + cpuop = core.CreateOperator( + "SoftmaxWithLoss", + ["X_cpu", "label_cpu"], + ["probs_cpu", "avgloss_cpu"], + spatial=1, + device_option=core.DeviceOption(caffe2_pb2.CPU) + ) + + n = 8 + D = 4 + W = 64 + int(np.random.rand(1) * 1024) + H = 64 + int(np.random.rand(1) * 1024) + + print("W: {} H: {}".format(W, H)) + + X = np.random.rand(n, D, H, W).astype(np.float32) + X = X + 1e-2 + + # Initialize label. Some of the labels are (-1), i.e "DONT CARE" + label = (np.random.rand(n, H, W) * (D + 1)).astype(np.int32) - 1 + + gpu0 = core.DeviceOption(caffe2_pb2.CUDA, 0) + workspace.FeedBlob("X_cpu", X) + workspace.FeedBlob("label_cpu", label) + workspace.FeedBlob("X_gpu", X, device_option=gpu0) + workspace.FeedBlob("label_gpu", label, device_option=gpu0) + + workspace.RunOperatorOnce(gpuop) + workspace.RunOperatorOnce(cpuop) + + probs_gpu = workspace.FetchBlob("probs_gpu") + probs_cpu = workspace.FetchBlob("probs_cpu") + loss_gpu = workspace.FetchBlob("avgloss_gpu") + loss_cpu = workspace.FetchBlob("avgloss_cpu") + + np.testing.assert_allclose(probs_gpu, probs_cpu, rtol=1e-4) + np.testing.assert_allclose(loss_gpu, loss_cpu, rtol=1e-1) diff --git a/caffe2/python/pipeline.py b/caffe2/python/pipeline.py index 1f8004a9951a..be362a4a7803 100644 --- a/caffe2/python/pipeline.py +++ b/caffe2/python/pipeline.py @@ -5,41 +5,191 @@ from __future__ import unicode_literals from caffe2.python import core, queue_util from caffe2.python.dataio import Reader, Writer +from caffe2.python.net_builder import NetBuilder +from caffe2.python.schema import as_record, Field +from caffe2.python.task import Task, TaskGroup -def processor_step( - reader, writer, num_threads=1, processor=None, name='processor'): +class Output(object): """ - Given a reader and a writer, couple them through a processor, running - across multiple threads. + Represents the result of a processor function. A processor can either + return an Output, or it can return a record, in which case an Output will be + created for it afterwards. + """ + def __init__(self, nets=None, record=None, should_stop=None): + builder_children = NetBuilder.current().get() + assert nets is None or len(builder_children) == 0, ( + 'Cannot both use `ops` syntax and return a list of nets.') + if nets is None: + nets = builder_children + if isinstance(nets, core.Net): + nets = [nets] + self.nets = [] if nets is None else list(nets) + self.record = None if record is None else as_record(record) + self.should_stop = should_stop + + +DEFAULT_QUEUE_CAPACITY = 10 + + +def _init_output(output, capacity, global_init_net, global_exit_net): + if isinstance(output, Writer): + assert capacity is None, 'capacity would not be used.' + out_queue = None + writer = output + elif hasattr(output, 'writer'): + assert capacity is None, 'capacity would not be used.' + out_queue = output + writer = output.writer() + elif output is None: + out_queue = queue_util.Queue( + capacity=( + capacity if capacity is not None + else DEFAULT_QUEUE_CAPACITY)) + writer = out_queue.writer() + else: + raise ValueError('output must be a reader, queue or stream.') + writer.setup_ex(global_init_net, global_exit_net) + return out_queue, writer + + +def make_processor(processor): + if processor is None: + return lambda rec: rec + elif isinstance(processor, core.Net): + return NetProcessor(processor) + else: + return processor + + +def normalize_processor_output(output): + """ + Allow for processors to return results in several formats. + TODO(azzolini): simplify once all processors use NetBuilder API. + """ + if isinstance(output, Output): + """ Processor returned an Output. """ + return output + elif isinstance(output, Field): + """ Processor returned a record. """ + return Output(record=output) + elif isinstance(output, tuple): + is_record_and_blob = ( + len(output) == 2 and + isinstance(output[0], Field) and + isinstance(output[1], core.BlobReference)) + if is_record_and_blob: + """ Processor returned (record, stop_blob) """ + return Output(None, *output) + else: + """ Processor returned (nets, record, stop_blob) """ + return Output(*output) + else: + """ Processor returned nets, no output """ + return Output(output) + + +def pipe( + input, output=None, num_threads=1, processor=None, name=None, + capacity=None, group=None): + """ + Given a Reader, Queue or DataStream in `input`, and optionally, a Writer, + Queue or DataStream in `output`, creates a Task that, when run, will + pipe the input into the output, using multiple parallel threads. + Additionally, if a processor is given, it will be called between reading + and writing steps, allowing it to transform the record. Args: - reader: an instance of dataio.Reader - writer: an instance of dataio.Wrier - num_threads: number of processing threads - processor: if provided, a function taking form: - (nets, out_record) = processor(record) - where `record` is a schema.Struct containing the input, - `nets` is the list of nets doing the transformation, and - `out_record` is a schema.Struct with transformed data; - name: Name to be given to nets and execution steps created. + input: either a Reader, Queue or DataStream that will be read + until a stop is signaled either by the reader or the + writer. + output: either a Writer, a Queue or a DataStream that will be + writen to as long as neither reader or writer signal + a stop condition. If output is not provided or is None, + a Queue is created with given `capacity` and writen to. + num_threads: number of concurrent threads used for processing and + piping. If set to 0, no Task is created, and a + reader is returned instead -- the reader returned will + read from the reader passed in and process it. + processor: (optional) function that takes an input record and + optionally returns a record; this will be called + between read and write steps. If the processor does + not return a record, a writer will not be instantiated. + Processor can also be a core.Net with input and output + records properly set. In that case, a NetProcessor is + instantiated, cloning the net for each of the threads. + name: (optional) name of the task to be created. + capacity: when output is not passed, a queue of given `capacity` + is created and written to. + group: (optional) explicitly add the created Task to this + TaskGroup, instead of using the currently active one. Returns: - Execution step that runs all threads of the processor in parallel. + Output Queue, DataStream, Reader, or None, depending on the parameters + passed. """ - assert isinstance(reader, Reader) - assert isinstance(writer, Writer) - global_init_net = core.Net(name + '_producer_global_init') + result, step = _pipe_step( + input, output, num_threads, processor, name, capacity, group) + if step is not None: + Task(step=step, group=group) + return result + + +def pipe_and_output( + input, output=None, num_threads=1, processor=None, name=None, + capacity=None, group=None, final_outputs=None): + """ + Similar to `pipe`, with the additional ability for the pipe Task to + return output values to the `Session` once done. + + Returns: + Tuple (out_queue, *task_outputs) + out_queue: same as return value of `pipe`. + task_outputs: TaskOutput object, fetchable from the client after + session.run() returns. + """ + result, step = _pipe_step( + input, output, num_threads, processor, name, capacity, group, + final_outputs) + assert step is not None + task = Task(step=step, group=group, outputs=final_outputs) + output = None + if final_outputs is not None: + output = task.outputs() + if type(final_outputs) not in (list, tuple): + output = output[0] + return result, output + + +def _pipe_step( + input, output=None, num_threads=1, processor=None, name=None, + capacity=None, group=None, final_outputs=None): + """ + """ + group = TaskGroup.current(group) + if name is None: + name = 'processor:%d' % group.num_registered_tasks() + + if isinstance(input, Reader): + reader = input + elif hasattr(input, 'reader'): + reader = input.reader() + else: + raise ValueError('in must be a reader, queue or streaam.') + + if processor is not None: + reader = ProcessingReader(reader, processor) + + if num_threads == 0: + assert output is None + return reader, None + global_exit_net = core.Net(name + '_producer_global_exit') + global_init_net = core.Net(name + '_producer_global_init') + out_queue = None + writer = None reader.setup_ex(global_init_net, global_exit_net) - writer.setup_ex(global_init_net, global_exit_net) - - def default_processor(fields): - return [], fields - - if processor is None: - processor = default_processor steps = [] for thread_id in range(num_threads): @@ -47,83 +197,119 @@ def processor_step( exit_net = core.Net(name + "_exit_net_%d" % thread_id) read_nets, status, rec = reader.read_record_ex(init_net, exit_net) - process_nets, rec = processor(rec) - write_nets, _ = writer.write_record_ex(rec, init_net, exit_net, status) + + if rec is not None: + if writer is None: + out_queue, writer = _init_output( + output, capacity, global_init_net, global_exit_net) + write_nets, _ = writer.write_record_ex( + rec, init_net, exit_net, status) + else: + write_nets = [] step = core.execution_step( name + "_thread_%d" % thread_id, [ core.execution_step(name + "_init_step", init_net), core.execution_step( name + "_worker_step", - list(read_nets) + list(process_nets) + list(write_nets), + list(read_nets) + list(write_nets), should_stop_blob=status ), core.execution_step(name + "_exit_step", exit_net) ] ) steps.append(step) - - return core.execution_step( + step = core.execution_step( "sender_step", [ core.execution_step('init_step', global_init_net), core.execution_step( "sender_steps", steps, concurrent_substeps=True), core.execution_step('finish_step', global_exit_net), - ] - ) - - -class LocalPipeline(object): - """ - Create a data processing pipeline consisting of a sequence of - multi-threaded processors communicating through queues. - """ - def __init__(self): - self.tasks = [] - self.init_net = core.Net('worker_init') - - def create_queue(self, capacity, schema): - """ - Create a queue that will be used to communicate between processors. - - Args: - capacity: max number of records in the queue - schema: a schema.Struct representing the schema of a record in - the queue. - - Returns: - A QueueWrapper containing a queue. - """ - return queue_util.QueueWrapper(self.init_net, capacity, schema) - - def add_task(self, task): - """ - Add a task to the pipeline. - This task will run in parallel to other tasks in the pipeline. - """ - self.tasks.append(task) - - def link(self, reader, writer, num_threads=1, processor=None): - """ - Add a task that will read from `reader`, and write to `writer`. - See function `processor_step` above for description of the arguments. - """ - self.add_task(processor_step(reader, writer, num_threads, processor)) - - def get_step(self): - """ - Create and return a Caffe2 execution step that will run all the tasks - of this pipeline in parallel. - """ - return core.execution_step('worker_step', [ - core.execution_step('worker_init', self.init_net), - core.execution_step( - 'tasks_step', self.tasks, concurrent_substeps=True) ]) + return out_queue, step - def get_step_and_output(self): - """ - Return a tuple (execution_step, output) to be used as one of the tasks - in a distributed pipeline. - """ - output = self.init_net.ConstantFill([], value=0.0) - return self.get_step(), [output] + +class ProcessingReader(Reader): + """ + Reader that reads from a upstream reader, calls the processor, and returns + the processed record. + """ + def __init__(self, reader, processor): + Reader.__init__(self) + self.reader = reader + self.processor = make_processor(processor) + + def setup_ex(self, init_net, finish_net): + self.reader.setup_ex(init_net, finish_net) + + def read_ex(self, init_net, exit_net): + read_nets, status, rec = self.reader.read_record_ex(init_net, exit_net) + with NetBuilder(): + # Current NetBuilder is optionally used inside the processor, + # then its children are retrived inside of + # normalize_processor_output. + # Once readers and writers also use NetBuilder, + # this logic will be more natural. + result = normalize_processor_output(self.processor(rec)) + read_nets += result.nets + if result.should_stop is not None: + stop_net = core.Net('stop_net') + stop_net.Copy([result.should_stop], [status]) + read_nets.append(stop_net) + if hasattr(self.processor, 'setup'): + init_net.add_attribute(TaskGroup.LOCAL_SETUP, self.processor) + self._set_schema(result.record) + fields = result.record.field_blobs() if result.record else None + return read_nets, status, fields + + +class NetProcessor(object): + """ + Processor that clones a core.Net each time it's called, executing + the cloned net as the processor. It requires the Net to have input + and (optionally) output records set, with net.set_input_record() and + net.set_output_record(). + """ + def __init__(self, net, stop_signal=None, thread_init_nets=None): + assert isinstance(net, core.Net) + assert stop_signal is None or isinstance( + stop_signal, core.BlobReference) + self.thread_init_nets = thread_init_nets or [] + self.net = net + self._stop_signal = stop_signal + self._blob_maps = [] + self._frozen = False + self._cloned_init_nets = [] + + def setup(self, init_net): + self._frozen = True + cloned_init_nets = self._cloned_init_nets + self._cloned_init_nets = [] + return cloned_init_nets + + def __call__(self, rec): + assert not self._frozen + prefix = '/worker:%d/' % len(self._blob_maps) + blob_remap = {} + for net in self.thread_init_nets: + new_net, _ = core.clone_and_bind_net( + net, str(net) + prefix, prefix, blob_remap) + self._cloned_init_nets.append(new_net) + + new_net, remappings = core.clone_and_bind_net( + self.net, str(self.net) + prefix, prefix, blob_remap, rec) + + if self._stop_signal is None: + stop_signal = None + elif str(self._stop_signal) in remappings: + stop_signal = core.BlobReference( + remappings[str(self._stop_signal)], + net=new_net) + else: + stop_signal = self._stop_signal + + self._blob_maps.append(remappings) + return Output([new_net], new_net.output_record(), stop_signal) + + def blob_maps(self): + self._frozen = True + return self._blob_maps diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 757d6bf90a50..5fa1fa8f93c3 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -3,10 +3,12 @@ #include #include +#include "caffe2/core/asan.h" #include "caffe2/core/db.h" #include "caffe2/core/predictor.h" namespace caffe2 { +namespace python { namespace py = pybind11; @@ -102,6 +104,57 @@ void switchWorkspaceInternal(const std::string& name, bool create_if_missing) { gCurrentWorkspaceName = name; } +FuncRegistery& gRegistery() { + // Always leak the objects registered here. + static FuncRegistery* r = new FuncRegistery(); + return *r; +} + +py::object& getOpFunc(const std::string& token) { + CAFFE_ENFORCE( + gRegistery().count(token), + "Python operator for ", + token, + " is not available. If you use distributed training it probably means " + "that python implementation has to be registered in each of the workers"); + return gRegistery()[token]; +} + +py::object& getGradientFunc(const std::string& token) { + return getOpFunc(token + "_gradient"); +} + +struct GetPythonGradient : public GradientMakerBase { + using GradientMakerBase::GradientMakerBase; + std::vector GetGradientDefs() override { + std::vector gradientInputs; + for (int i = 0; i < def_.input_size(); ++i) { + gradientInputs.push_back(I(i)); + } + for (int i = 0; i < def_.output_size(); ++i) { + gradientInputs.push_back(O(i)); + } + for (int i = 0; i < def_.output_size(); ++i) { + gradientInputs.push_back(GO(i)); + } + std::vector gradientOutputs; + for (int i = 0; i < def_.input_size(); ++i) { + gradientOutputs.push_back(GI(i)); + } + + return SingleGradientDef( + "PythonGradient", "", gradientInputs, gradientOutputs); + } +}; + +REGISTER_CPU_OPERATOR(Python, PythonOp); +REGISTER_CPU_OPERATOR(PythonGradient, PythonGradientOp); +// Always allow running in-place +OPERATOR_SCHEMA(Python).AllowInplace([](int, int) { return true; }); +OPERATOR_SCHEMA(PythonGradient).AllowInplace([](int, int) { return true; }); + +REGISTER_GRADIENT(Python, GetPythonGradient); + void addObjectMethods(py::module& m) { py::class_(m, "Net").def("run", [](NetBase* net) { py::gil_scoped_release g; @@ -162,8 +215,60 @@ void addObjectMethods(py::module& m) { py::arg("arg"), py::arg("device_option") = py::none()); + py::class_(m, "TensorCPU") + .def_property_readonly( + "data", + [](TensorCPU* t) -> py::object { + CAFFE_ENFORCE(t->size() > 0); + std::vector npy_dims; + for (const auto dim : t->dims()) { + npy_dims.push_back(dim); + } + // TODO: use float as default data type if it's a new Tensor. + // consider to support setting data type + TypeMeta meta = t->meta(); + if (meta.id() == 0) { + meta = TypeMeta::Make(); + } + auto numpy_type = CaffeToNumpyType(meta); + if (numpy_type == NPY_OBJECT) { + PyObject* array = + PyArray_SimpleNew(t->ndim(), npy_dims.data(), numpy_type); + void* outPtr = static_cast( + PyArray_DATA(reinterpret_cast(array))); + PyObject** outObj = reinterpret_cast(outPtr); + auto* str = t->template mutable_data(); + for (TIndex i = 0; i < t->size(); ++i) { + outObj[i] = PyBytes_FromStringAndSize(str->data(), str->size()); + str++; + // cleanup on failure + if (outObj[i] == nullptr) { + for (TIndex j = 0; j < i; ++j) { + Py_DECREF(outObj[j]); + } + Py_DECREF(array); + CAFFE_THROW( + "Failed to allocate string for ndarray of strings."); + } + } + return pybind11::object(array, /* borrowed= */ false); + } + PyObject* array = PyArray_SimpleNewFromData( + t->ndim(), + npy_dims.data(), + numpy_type, + t->raw_mutable_data(meta)); + return py::object(array, /* borrowed= */ false); + }) + .def_property_readonly( + "_shape", [](const TensorCPU& t) { return t.dims(); }) + .def("_reshape", [](TensorCPU* t, std::vector dims) { + t->Resize(dims); + }); + py::class_(m, "Workspace") .def(py::init<>()) + .def(py::init()) .def_property_readonly( "nets", [](Workspace* self) { @@ -266,11 +371,11 @@ void addObjectMethods(py::module& m) { .def("put", &db::Transaction::Put) .def("commit", &db::Transaction::Commit); py::class_(m, "Cursor") - .def("supports_seak", &db::Cursor::SupportsSeek) + .def("supports_seek", &db::Cursor::SupportsSeek) .def("seek_to_first", &db::Cursor::SeekToFirst) .def("next", &db::Cursor::Next) - .def("key", &db::Cursor::key) - .def("value", &db::Cursor::value) + .def("key", [](db::Cursor* self) -> py::bytes { return self->key(); }) + .def("value", [](db::Cursor* self) -> py::bytes { return self->value(); }) .def("valid", &db::Cursor::Valid); py::enum_(m, "Mode") .value("read", db::Mode::READ) @@ -322,6 +427,8 @@ void addObjectMethods(py::module& m) { } void addGlobalMethods(py::module& m) { + m.attr("is_asan") = py::bool_(CAFFE2_ASAN_ENABLED); + m.def("global_init", [](std::vector args) -> void { int argc = args.size(); std::vector argv; @@ -351,9 +458,14 @@ void addGlobalMethods(py::module& m) { return keys; }); m.def("on_module_exit", []() { gWorkspaces.clear(); }); + // create_if_missing not used by necessary for pybind to do + // properly do function overloading. + m.def("switch_workspace", [](Workspace* ws, py::object create_if_missing) { + gWorkspace = ws; + }); m.def( "switch_workspace", - [](const std::string& name, py::object create_if_missing) { + [](const std::string& name, const py::object create_if_missing) { if (create_if_missing == py::none()) { return switchWorkspaceInternal(name, false); } @@ -390,6 +502,10 @@ void addGlobalMethods(py::module& m) { } return names; }); + m.def("local_blobs", []() { + CAFFE_ENFORCE(gWorkspace); + return gWorkspace->LocalBlobs(); + }); m.def("blobs", []() { CAFFE_ENFORCE(gWorkspace); return gWorkspace->Blobs(); @@ -400,15 +516,21 @@ void addGlobalMethods(py::module& m) { }); m.def("create_net", [](py::bytes net_def) { caffe2::NetDef proto; - CAFFE_ENFORCE(proto.ParseFromString(net_def)); - CAFFE_ENFORCE(gWorkspace->CreateNet(proto)); + CAFFE_ENFORCE( + proto.ParseFromString(net_def), + "Can't parse net proto: ", + std::string(net_def)); + CAFFE_ENFORCE( + gWorkspace->CreateNet(proto), + "Error creating net with proto: ", + std::string(net_def)); return true; }); m.def("run_net", [](const std::string& name) { CAFFE_ENFORCE(gWorkspace); - CAFFE_ENFORCE(gWorkspace->GetNet(name)); + CAFFE_ENFORCE(gWorkspace->GetNet(name), "Can't find net ", name); py::gil_scoped_release g; - CAFFE_ENFORCE(gWorkspace->RunNet(name)); + CAFFE_ENFORCE(gWorkspace->RunNet(name), "Error running net ", name); return true; }); m.def( @@ -521,6 +643,33 @@ void addGlobalMethods(py::module& m) { CAFFE_ENFORCE(blob->Deserialize(serialized.cast())); }); + m.def("register_python_op", [](py::object func) { + CAFFE_ENFORCE(func != py::none()); + const std::string name = func.attr("__name__").cast(); + // Unique name since registry is never cleared. + const std::string token = name + to_string(gRegistery().size()); + CAFFE_ENFORCE(gRegistery().find(name) == gRegistery().end()); + gRegistery()[token] = func; + return token; + }); + + m.def( + "register_python_gradient_op", + [](const std::string& token, py::object func) { + CAFFE_ENFORCE(func != py::none()); + CAFFE_ENFORCE(gRegistery().find(token) != gRegistery().end()); + gRegistery()[token + "_gradient"] = func; + }); + +#define CAFFE2_CPU_FEATURE_SUPPORT(feature) \ + m.def("builtin_cpu_supports_" #feature, []() { \ + return __builtin_cpu_supports(#feature); \ + }) + + CAFFE2_CPU_FEATURE_SUPPORT(avx2); + +#undef CAFFE2_CPU_FEATURE_SUPPORT + auto initialize = [&]() { // Initialization of the module ([]() { @@ -552,4 +701,6 @@ PYBIND11_PLUGIN(caffe2_pybind11_state) { addObjectMethods(m); return m.ptr(); } -} + +} // namespace python +} // namespace caffe2 diff --git a/caffe2/python/pybind_state.h b/caffe2/python/pybind_state.h index 7f60efb73b93..8d4014621582 100644 --- a/caffe2/python/pybind_state.h +++ b/caffe2/python/pybind_state.h @@ -1,23 +1,37 @@ #pragma once +#include + #include "caffe2/core/context.h" #include "caffe2/core/init.h" #include "caffe2/core/logging.h" #include "caffe2/core/net.h" #include "caffe2/core/operator.h" #include "caffe2/core/scope_guard.h" +#include "caffe2/core/tensor.h" #include "caffe2/core/types.h" #include "caffe2/core/workspace.h" #include "caffe2/proto/caffe2.pb.h" #include +#include #include #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #define PY_ARRAY_UNIQUE_SYMBOL caffe2_python_ARRAY_API #include +// Temporary solution for numpy < 1.7 versions: old macro, no promises. +// You're strongly advised to upgrade to >= 1.7. +#ifndef NPY_ARRAY_C_CONTIGUOUS +#define NPY_ARRAY_C_CONTIGUOUS NPY_C_CONTIGUOUS +#define PyArray_SetBaseObject(arr, x) (PyArray_BASE(arr) = (x)) +#endif + namespace caffe2 { +namespace python { + +namespace py = pybind11; // Add methods common to both CPU and GPU mode. void addGlobalMethods(pybind11::module& m); @@ -161,4 +175,93 @@ class TensorFeeder : public BlobFeederBase { context.FinishDeviceComputation(); } }; -} + +// Python Op implementations. +using FuncRegistery = std::unordered_map; +FuncRegistery& gRegistery(); + +py::object& getOpFunc(const std::string& token); + +py::object& getGradientFunc(const std::string& token); + +class PythonOpBase : public Operator { + public: + using Operator::Operator; + + bool RunOnDevice() final { + std::vector inputs; + inputs.reserve(InputSize()); + for (auto i = 0; i < InputSize(); ++i) { + inputs.push_back(const_cast(&Input(i))); + } + std::vector outputs; + outputs.reserve(OutputSize()); + for (auto i = 0; i < OutputSize(); ++i) { + outputs.push_back(Output(i)); + } + auto& pyFunc = getFunc(); + { + // Acquire GIL for call to Python runtime. + py::gil_scoped_acquire g; + try { + pyFunc(inputs, outputs); + } catch (const py::error_already_set& e) { + LOG(ERROR) << "Exception encountered running PythonOp function: " + << e.what() << "\nTraceback: "; + PyObject *type = nullptr, *value = nullptr, *trace = nullptr; + PyErr_Fetch(&type, &value, &trace); + PyTracebackObject* traceback = + reinterpret_cast(trace); + vector trace_vec; + while (traceback) { + trace_vec.push_back(traceback); + traceback = traceback->tb_next; + } + for (int i = trace_vec.size() - 1; i >= 0; --i) { + int line = trace_vec[i]->tb_lineno; + const char* filename = + PyString_AsString(trace_vec[i]->tb_frame->f_code->co_filename); + const char* funcname = + PyString_AsString(trace_vec[i]->tb_frame->f_code->co_name); + LOG(ERROR) << " # " << trace_vec.size() - i - 1 << " " << filename + << " (" << line << "): " << funcname; + } + Py_XDECREF(type); + Py_XDECREF(value); + Py_XDECREF(trace); + return false; + } + } + return true; + } + + private: + virtual py::object& getFunc() = 0; +}; + +class PythonOp final : public PythonOpBase { + public: + using PythonOpBase::PythonOpBase; + + private: + py::object& getFunc() override { + const std::string& token = + OperatorBase::GetSingleArgument("token", ""); + return getOpFunc(token); + } +}; + +class PythonGradientOp final : public PythonOpBase { + public: + using PythonOpBase::PythonOpBase; + + private: + py::object& getFunc() override { + const std::string& token = + OperatorBase::GetSingleArgument("token", ""); + return getGradientFunc(token); + } +}; + +} // namespace python +} // namespace caffe2 diff --git a/caffe2/python/pybind_state_gpu.cc b/caffe2/python/pybind_state_gpu.cc index 3ffd0a1c39ef..ac99c4e33324 100644 --- a/caffe2/python/pybind_state_gpu.cc +++ b/caffe2/python/pybind_state_gpu.cc @@ -10,8 +10,13 @@ #include #include "caffe2/core/context_gpu.h" +#include "caffe2/operators/operator_fallback_gpu.h" namespace caffe2 { +namespace python { + +REGISTER_CUDA_OPERATOR(Python, GPUFallbackOp); +REGISTER_CUDA_OPERATOR(PythonGradient, GPUFallbackOp); REGISTER_BLOB_FETCHER((TypeMeta::Id()), TensorFetcher); REGISTER_BLOB_FEEDER(CUDA, TensorFeeder); @@ -39,4 +44,5 @@ PYBIND11_PLUGIN(caffe2_pybind11_state_gpu) { addObjectMethods(m); return m.ptr(); } -} +} // namespace python +} // namespace caffe2 diff --git a/caffe2/python/op/python_test.py b/caffe2/python/python_op_test.py similarity index 91% rename from caffe2/python/op/python_test.py rename to caffe2/python/python_op_test.py index b3fa04daec89..35c0addaab57 100644 --- a/caffe2/python/op/python_test.py +++ b/caffe2/python/python_op_test.py @@ -3,7 +3,7 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals from caffe2.python import core, workspace -from caffe2.python.op.python import CreatePythonOperator +from caffe2.python.core import CreatePythonOperator import caffe2.python.hypothesis_test_util as hu from hypothesis import given import hypothesis.strategies as st @@ -106,8 +106,8 @@ class PythonOpTest(hu.HypothesisTestCase): y = workspace.FetchBlob(str(i)) np.testing.assert_almost_equal(x, y) - @given(x=hu.tensor(), in_place=st.booleans()) - def test_gradient(self, x, in_place): + @given(x=hu.tensor(), in_place=st.booleans(), **hu.gcs) + def test_gradient(self, x, in_place, gc, dc): def f(inputs, outputs): outputs[0].reshape(inputs[0].shape) outputs[0].data[...] = inputs[0].data * 2 @@ -122,10 +122,11 @@ class PythonOpTest(hu.HypothesisTestCase): op = CreatePythonOperator( f, ["x"], ["x" if in_place else "y"], grad_f=grad_f) - self.assertGradientChecks(hu.cpu_do, op, [x], 0, [0]) + self.assertGradientChecks(gc, op, [x], 0, [0]) + self.assertDeviceChecks(dc, op, [x], [0]) - @given(inputs=hu.tensors(n=2)) - def test_gradient_multiple(self, inputs): + @given(inputs=hu.tensors(n=2), **hu.gcs) + def test_gradient_multiple(self, inputs, gc, dc): (x1, x2) = inputs def f(inputs, outputs): @@ -147,4 +148,5 @@ class PythonOpTest(hu.HypothesisTestCase): op = CreatePythonOperator(f, ["x1", "x2"], ["y1", "y2"], grad_f=grad_f) for idx in [0, 1]: - self.assertGradientChecks(hu.cpu_do, op, [x1, x2], idx, [0, 1]) + self.assertGradientChecks(gc, op, [x1, x2], idx, [0, 1]) + self.assertDeviceChecks(dc, op, [x1, x2], [0, 1]) diff --git a/caffe2/python/queue_util.py b/caffe2/python/queue_util.py index a703358d0311..2cae95a365a8 100644 --- a/caffe2/python/queue_util.py +++ b/caffe2/python/queue_util.py @@ -4,63 +4,75 @@ from __future__ import print_function from __future__ import unicode_literals from caffe2.python import core, dataio +from caffe2.python.task import TaskGroup -class QueueReader(dataio.Reader): - def __init__(self, queue, num_blobs=None, schema=None): - dataio.Reader.__init__(self, schema) - assert schema is not None or num_blobs is not None, ( - 'Either schema or num_blobs must be provided.') - - self.queue = queue - self.num_blobs = num_blobs - - if schema is not None: - schema_num_blobs = len(schema.field_names()) - assert num_blobs is None or num_blobs == schema_num_blobs - self.num_blobs = schema_num_blobs +class _QueueReader(dataio.Reader): + def __init__(self, wrapper): + assert wrapper.schema is not None, ( + 'Queue needs a schema in order to be read from.') + dataio.Reader.__init__(self, wrapper.schema()) + self._wrapper = wrapper def setup_ex(self, init_net, exit_net): - exit_net.CloseBlobsQueue([self.queue], 0) + exit_net.CloseBlobsQueue([self._wrapper.queue()], 0) def read_ex(self, local_init_net, local_finish_net): + self._wrapper._new_reader(local_init_net) dequeue_net = core.Net('dequeue_net') - fields, status_blob = dequeue(dequeue_net, self.queue, self.num_blobs) + fields, status_blob = dequeue( + dequeue_net, + self._wrapper.queue(), + len(self.schema().field_names())) return [dequeue_net], status_blob, fields -class QueueWriter(dataio.Writer): - def __init__(self, queue): - self.queue = queue +class _QueueWriter(dataio.Writer): + def __init__(self, wrapper): + self._wrapper = wrapper def setup_ex(self, init_net, exit_net): - exit_net.CloseBlobsQueue([self.queue], 0) + exit_net.CloseBlobsQueue([self._wrapper.queue()], 0) def write_ex(self, fields, local_init_net, local_finish_net, status): + self._wrapper._new_writer(self.schema(), local_init_net) enqueue_net = core.Net('enqueue_net') - enqueue(enqueue_net, self.queue, fields, status) + enqueue(enqueue_net, self._wrapper.queue(), fields, status) return [enqueue_net] -class QueueWrapper(object): - def __init__(self, init_net, capacity, schema): - self._queue = init_net.CreateBlobsQueue( - [], - capacity=capacity, - num_blobs=len(schema.field_names())) - self._schema = schema +class QueueWrapper(dataio.Pipe): + def __init__(self, handler, schema=None): + dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP) + self._queue = handler def reader(self): - return QueueReader(self._queue, schema=self._schema) + return _QueueReader(self) def writer(self): - return QueueWriter(self._queue) + return _QueueWriter(self) def queue(self): return self._queue - def schema(self): - return self._schema + +class Queue(QueueWrapper): + def __init__(self, capacity, schema=None, name='queue'): + # find a unique blob name for the queue + net = core.Net(name) + queue_blob = net.AddExternalInput(net.NextName('handler')) + QueueWrapper.__init__(self, queue_blob, schema) + self.capacity = capacity + self._setup_done = False + + def setup(self, global_init_net): + assert self._schema, 'This queue does not have a schema.' + self._setup_done = True + global_init_net.CreateBlobsQueue( + [], + [self._queue], + capacity=self.capacity, + num_blobs=len(self._schema.field_names())) def enqueue(net, queue, data_blobs, status=None): diff --git a/caffe2/python/schema.py b/caffe2/python/schema.py index 336e74a02145..b66b126899fe 100644 --- a/caffe2/python/schema.py +++ b/caffe2/python/schema.py @@ -21,9 +21,10 @@ import numpy as np from caffe2.python import core from caffe2.python import workspace from caffe2.python.core import BlobReference -from collections import OrderedDict +from collections import OrderedDict, namedtuple logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) def _join_field_name(prefix, suffix): @@ -37,19 +38,42 @@ def _join_field_name(prefix, suffix): return '' -def _normalize_field(field_or_type_or_blob): +def _normalize_field(field_or_type_or_blob, keep_blobs=True): """Clones/normalizes a field before adding it to a container.""" if isinstance(field_or_type_or_blob, Field): - return field_or_type_or_blob.clone() + return field_or_type_or_blob.clone(keep_blobs=keep_blobs) elif type(field_or_type_or_blob) in (type, np.dtype): return Scalar(dtype=field_or_type_or_blob) else: return Scalar(blob=field_or_type_or_blob) +FeatureSpec = namedtuple( + 'FeatureSpec', + ['feature_type', 'feature_names', 'feature_ids']) + + +class Metadata(namedtuple('Metadata', ['categorical_limit', 'expected_value', + 'feature_specs'])): + """Represents additional information associated with a scalar in schema. + + `categorical_limit` - for fields of integral type that are guaranteed to be + non-negative it specifies the maximum possible value plus one. It's often + used as a size of an embedding table. + + `expected_value` - anticipated average value of elements in the field. + Usually makes sense for length fields of lists. + + `feature_specs` - information about the features that contained in this + field. For example if field have more then 1 feature it can have list of + feature names contained in this field.""" + __slots__ = () +Metadata.__new__.__defaults__ = (None, None, None) + class Field(object): """Represents an abstract field type in a dataset. """ + def __init__(self, children): """Derived classes must call this after their initialization.""" self._parent = (None, 0) @@ -60,6 +84,9 @@ class Field(object): offset += len(child.field_names()) self._field_offsets.append(offset) + def clone_schema(self): + return self.clone(keep_blobs=False) + def field_names(self): """Return the children field names for this field.""" raise NotImplementedError('Field is an abstract class.') @@ -68,6 +95,10 @@ class Field(object): """Return the numpy.dtype for each of the children fields.""" raise NotImplementedError('Field is an abstract class.') + def field_metadata(self): + """Return the Metadata for each of the children fields.""" + raise NotImplementedError('Field is an abstract class.') + def field_blobs(self): """Return the list of blobs with contents for this Field. Values can either be all numpy.ndarray or BlobReference. @@ -75,7 +106,16 @@ class Field(object): """ raise NotImplementedError('Field is an abstract class.') - def clone(self): + def all_scalars(self): + """Return the list of all Scalar instances in the Field. + The order is the same as for field_names() or field_blobs()""" + raise NotImplementedError('Field is an abstract class.') + + def has_blobs(self): + """Return True if every scalar of this field has blobs.""" + raise NotImplementedError('Field is an abstract class.') + + def clone(self, keep_blobs=True): """Clone this Field along with its children.""" raise NotImplementedError('Field is an abstract class.') @@ -115,7 +155,9 @@ class Field(object): def __eq__(self, other): """Equivalance of two schemas""" return ((self.field_names() == other.field_names()) and - (self.field_types() == other.field_types())) + (self.field_types() == other.field_types()) and + (self.field_metadata() == other.field_metadata())) + class List(Field): """Represents a variable-length list. @@ -125,6 +167,7 @@ class List(Field): additional `lengths` field, which will contain the size of each list under the parent domain. """ + def __init__(self, values, lengths_blob=None): self.lengths = Scalar(np.int32, lengths_blob) self._items = _normalize_field(values) @@ -141,11 +184,20 @@ class List(Field): def field_types(self): return self.lengths.field_types() + self._items.field_types() + def field_metadata(self): + return self.lengths.field_metadata() + self._items.field_metadata() + def field_blobs(self): return self.lengths.field_blobs() + self._items.field_blobs() - def clone(self): - return List(self._items, self.lengths._blob) + def all_scalars(self): + return self.lengths.all_scalars() + self._items.all_scalars() + + def has_blobs(self): + return self.lengths.has_blobs() and self._items.has_blobs() + + def clone(self, keep_blobs=True): + return List(self._items, self.lengths._blob if keep_blobs else None) def __getattr__(self, item): """If the value of this list is a struct, @@ -163,6 +215,7 @@ class List(Field): class Struct(Field): """Represents a named list of fields sharing the same domain. """ + def __init__(self, *fields): for field in fields: assert len(field) == 2 @@ -190,14 +243,33 @@ class Struct(Field): types += field.field_types() return types + def field_metadata(self): + metadata = [] + for name, field in self.fields.items(): + metadata += field.field_metadata() + return metadata + def field_blobs(self): blobs = [] for name, field in self.fields.items(): blobs += field.field_blobs() return blobs - def clone(self): - return Struct(*self.fields.items()) + def all_scalars(self): + scalars = [] + for name, field in self.fields.items(): + scalars += field.all_scalars() + return scalars + + def has_blobs(self): + return all(field.has_blobs() for field in self.fields.values()) + + def clone(self, keep_blobs=True): + normalized_fields = [ + (k, _normalize_field(v, keep_blobs=keep_blobs)) + for k, v in self.fields.items() + ] + return Struct(*normalized_fields) def __getitem__(self, item): if isinstance(item, list) or isinstance(item, tuple): @@ -262,22 +334,39 @@ class Scalar(Field): blob living in a caffe2 Workspace. If blob of different types are passed, a conversion to numpy.ndarray is attempted. """ - def __init__(self, dtype=None, blob=None): - self.set(dtype, blob) + + def __init__(self, dtype=None, blob=None, metadata=None): + self._metadata = None + self.set(dtype, blob, metadata) Field.__init__(self, []) def field_names(self): return [''] + def field_type(self): + return self.dtype + def field_types(self): return [self.dtype] + def field_metadata(self): + return [self._metadata] + + def has_blobs(self): + return self._blob is not None + def field_blobs(self): assert self._blob is not None, 'Value is not set for this field.' return [self._blob] - def clone(self): - return Scalar(dtype=self._original_dtype, blob=self._blob) + def all_scalars(self): + return [self] + + def clone(self, keep_blobs=True): + return Scalar( + dtype=self._original_dtype, + blob=self._blob if keep_blobs else None, + metadata=self._metadata) def get(self): """Gets the current blob of this Scalar field.""" @@ -288,7 +377,30 @@ class Scalar(Field): """Shortcut for self.get()""" return self.get() - def set(self, dtype=None, blob=None): + @property + def metadata(self): + return self._metadata + + def set_metadata(self, value): + assert isinstance(value, Metadata), \ + 'metadata must be Metadata, got {}'.format(type(value)) + self._metadata = value + self._validate_metadata() + + def _validate_metadata(self): + if self._metadata is None: + return + if (self._metadata.categorical_limit is not None and + self.dtype is not None): + assert np.issubdtype(self.dtype, np.integer), \ + "`categorical_limit` can be specified only in integral " + \ + "fields but got {}".format(self.dtype) + + def set_value(self, blob): + """Sets only the blob field still validating the existing dtype""" + self.set(dtype=self._original_dtype, blob=blob) + + def set(self, dtype=None, blob=None, metadata=None): """Set the type and/or blob of this scalar. See __init__ for details. Args: @@ -300,6 +412,8 @@ class Scalar(Field): a conversion to numpy.ndarray is attempted. Strings aren't accepted, since they can be ambiguous. If you want to pass a string, to either BlobReference(blob) or np.array(blob). + metadata: optional instance of Metadata, if provided overrides + the metadata information of the scalar """ if blob is not None and isinstance(blob, core.basestring): raise ValueError( @@ -313,7 +427,7 @@ class Scalar(Field): # If blob is not None and it is not a BlobReference, we assume that # it is actual tensor data, so we will try to cast it to an numpy array. if blob is not None and not isinstance(blob, BlobReference): - if dtype is not None: + if dtype is not None and dtype != np.void: blob = np.array(blob, dtype=dtype.base) # if array is empty we may need to reshape a little if blob.size == 0: @@ -321,10 +435,16 @@ class Scalar(Field): else: assert isinstance(blob, np.ndarray), ( 'Invalid blob type: %s' % str(type(blob))) - assert len(blob.shape), ('Value must be at least a 1D array.') + + # reshape scalars into 1D arrays + # TODO(azzolini): figure out better way of representing this + if len(blob.shape) == 0: + blob = blob.reshape((1,)) + # infer inner shape from the blob given # TODO(dzhulgakov): tweak this to make it work with PackedStruct - if len(blob.shape) > 1: + if (len(blob.shape) > 1 and dtype is not None and + dtype.base != np.void): dtype = np.dtype((dtype.base, blob.shape[1:])) # if we were still unable to infer the dtype if dtype is None: @@ -334,10 +454,14 @@ class Scalar(Field): 'Use from_dtype instead.') self.dtype = dtype self._blob = blob + if metadata is not None: + self.set_metadata(metadata) + self._validate_metadata() def set_type(self, dtype): self._original_dtype = dtype self.dtype = np.dtype(dtype or np.void) + self._validate_metadata() def id(self): """ @@ -406,6 +530,7 @@ def from_dtype(dtype, _outer_shape=()): class _SchemaNode(object): """This is a private class used to represent a Schema Node""" + def __init__(self, name, type_str=''): self.name = name self.children = [] @@ -475,20 +600,26 @@ class _SchemaNode(object): logger.info(self.type_str) -def from_column_list(col_names, col_types=None, col_blobs=None): +def from_column_list(col_names, col_types=None, col_blobs=None, + col_metadata=None): """ Given a list of names, types, and optionally values, construct a Schema. """ if col_types is None: col_types = [None] * len(col_names) + if col_metadata is None: + col_metadata = [None] * len(col_names) if col_blobs is None: col_blobs = [None] * len(col_names) assert len(col_names) == len(col_types), ( 'col_names and col_types must have the same length.') + assert len(col_names) == len(col_metadata), ( + 'col_names and col_metadata must have the same length.') assert len(col_names) == len(col_blobs), ( 'col_names and col_blobs must have the same length.') root = _SchemaNode('root', 'Struct') - for col_name, col_type, col_blob in zip(col_names, col_types, col_blobs): + for col_name, col_type, col_blob, col_metadata in zip( + col_names, col_types, col_blobs, col_metadata): columns = col_name.split(':') current = root for i in range(len(columns)): @@ -497,7 +628,8 @@ def from_column_list(col_names, col_types=None, col_blobs=None): field = None if i == len(columns) - 1: type_str = col_type - field = Scalar(dtype=col_type, blob=col_blob) + field = Scalar(dtype=col_type, blob=col_blob, + metadata=col_metadata) next = current.add_child(name, type_str) if field is not None: next.field = field @@ -515,31 +647,63 @@ def from_blob_list(schema, values): assert isinstance(schema, Field), 'Argument `schema` must be a Field.' if isinstance(values, BlobReference): values = [values] - names = schema.field_names() - types = schema.field_types() - assert len(names) == len(values), ( - 'Values must have %d elements, got %d.' % (len(names), len(values))) - return from_column_list(names, types, values) + record = schema.clone_schema() + scalars = record.all_scalars() + assert len(scalars) == len(values), ( + 'Values must have %d elements, got %d.' % (len(scalars), len(values))) + for scalar, value in zip(scalars, values): + scalar.set_value(value) + return record -def FetchRecord(blob_record): +def as_record(value): + if isinstance(value, Field): + return value + elif isinstance(value, list) or isinstance(value, tuple): + is_field_list = all( + f is tuple and len(f) == 2 and isinstance(f[0], core.basestring) + for f in value) + if is_field_list: + return Struct(*[(k, as_record(v)) for k, v in value]) + else: + return Tuple(*[as_record(f) for f in value]) + elif isinstance(value, dict): + return Struct(*[(k, as_record(v)) for k, v in value.items()]) + else: + return _normalize_field(value) + + +def FetchRecord(blob_record, ws=None): """ Given a record containing BlobReferences, return a new record with same schema, containing numpy arrays, fetched from the current active workspace. """ + def fetch(v): + if ws is None: + return workspace.FetchBlob(str(v)) + else: + return ws.blobs[str(v)].fetch() + assert isinstance(blob_record, Field) field_blobs = blob_record.field_blobs() assert all(isinstance(v, BlobReference) for v in field_blobs) - field_arrays = [workspace.FetchBlob(value) for value in field_blobs] + field_arrays = [fetch(value) for value in field_blobs] return from_blob_list(blob_record, field_arrays) -def FeedRecord(blob_record, arrays): +def FeedRecord(blob_record, arrays, ws=None): """ Given a Record containing blob_references and arrays, which is either a list of numpy arrays or a Record containing numpy arrays, feeds the record to the current workspace. """ + def feed(b, v): + if ws is None: + workspace.FeedBlob(str(b), v) + else: + ws.create_blob(str(b)) + ws.blobs[str(b)].feed(v) + assert isinstance(blob_record, Field) field_blobs = blob_record.field_blobs() assert all(isinstance(v, BlobReference) for v in field_blobs) @@ -549,7 +713,7 @@ def FeedRecord(blob_record, arrays): assert len(arrays) == len(field_blobs), ( 'Values must contain exactly %d ndarrays.' % len(field_blobs)) for blob, array in zip(field_blobs, arrays): - workspace.FeedBlob(blob, array) + feed(blob, array) def NewRecord(net, schema): @@ -558,12 +722,41 @@ def NewRecord(net, schema): returning a record containing BlobReferences. The BlobReferences will be added as ExternalInputs of the given net. """ + if isinstance(schema, Scalar): + result = schema.clone() + result.set_value(blob=BlobReference(net.NextName('unnamed_scalar'))) + return result + assert isinstance(schema, Field), 'Record must be a schema.Field instance.' blob_refs = [ net.AddExternalInput(net.NextName(prefix=name)) for name in schema.field_names()] return from_blob_list(schema, blob_refs) + +def ConstRecord(net, array_record): + """ + Given a record of arrays, returns a record of blobs, + initialized with net.Const. + """ + blob_record = NewRecord(net, array_record) + for blob, array in zip( + blob_record.field_blobs(), + array_record.field_blobs()): + net.Const(array, blob) + return blob_record + + +def InitEmptyRecord(net, schema_or_record): + if not schema_or_record.has_blobs(): + record = NewRecord(net, schema_or_record) + else: + record = schema_or_record + for blob in record.field_blobs(): + net.ConstantFill([], blob, shape=[0]) + return record + + _DATA_TYPE_FOR_DTYPE = [ (np.str, core.DataType.STRING), (np.float32, core.DataType.FLOAT), @@ -578,6 +771,27 @@ _DATA_TYPE_FOR_DTYPE = [ ] +def is_schema_subset(schema, original_schema): + # TODO add more checks + return set(schema.field_names()).issubset( + set(original_schema.field_names())) + + +def equal_schemas(schema, original_schema): + assert isinstance(schema, Field) + assert isinstance(original_schema, Field) + # TODO allow for more compatibility + return schema.field_names() == original_schema.field_names() and\ + schema.field_types() == original_schema.field_types() + + +def schema_check(schema, previous=None): + record = as_record(schema) + if previous is not None: + assert equal_schemas(schema, previous) + return record + + def data_type_for_dtype(dtype): for np_type, dt in _DATA_TYPE_FOR_DTYPE: if dtype.base == np_type: diff --git a/caffe2/python/schema_test.py b/caffe2/python/schema_test.py index aea2c80db783..a8e75b614961 100644 --- a/caffe2/python/schema_test.py +++ b/caffe2/python/schema_test.py @@ -14,8 +14,7 @@ class TestDB(unittest.TestCase): def testPicklable(self): s = schema.Struct( ('field1', schema.Scalar(dtype=np.int32)), - ('field2', schema.List( - schema.Scalar(dtype=str))) + ('field2', schema.List(schema.Scalar(dtype=str))) ) s2 = pickle.loads(pickle.dumps(s)) for r in (s, s2): @@ -60,10 +59,10 @@ class TestDB(unittest.TestCase): def testRawTuple(self): s = schema.RawTuple(2) self.assertEquals( - s, - schema.Struct( - ('field_0', schema.Scalar()), - ('field_1', schema.Scalar()))) + s, schema.Struct( + ('field_0', schema.Scalar()), ('field_1', schema.Scalar()) + ) + ) self.assertEquals(s[0], schema.Scalar()) self.assertEquals(s[1], schema.Scalar()) @@ -81,3 +80,34 @@ class TestDB(unittest.TestCase): ('field1', schema.Scalar(dtype=np.int32)), ) ) + + def testPreservesMetadata(self): + s = schema.Struct( + ('a', schema.Scalar(np.float32)), ( + 'b', schema.Scalar( + np.int32, + metadata=schema.Metadata(categorical_limit=5) + ) + ) + ) + self.assertEqual(None, s.a.metadata) + self.assertEqual(5, s.b.metadata.categorical_limit) + sc = s.clone() + self.assertEqual(None, sc.a.metadata) + self.assertEqual(5, sc.b.metadata.categorical_limit) + sv = schema.from_blob_list(s, [np.array([3.4]), np.array([2])]) + self.assertEqual(None, sv.a.metadata) + self.assertEqual(5, sv.b.metadata.categorical_limit) + + def testPreservesEmptyFields(self): + s = schema.Struct( + ('a', schema.Scalar(np.float32)), + ('b', schema.Struct()), + ) + sc = s.clone() + self.assertIn("a", sc.fields) + self.assertIn("b", sc.fields) + sv = schema.from_blob_list(s, [np.array([3.4])]) + self.assertIn("a", sv.fields) + self.assertIn("b", sv.fields) + self.assertEqual(0, len(sv.b.fields)) diff --git a/caffe2/python/scope.py b/caffe2/python/scope.py index 497507c80328..2baee93a75e0 100644 --- a/caffe2/python/scope.py +++ b/caffe2/python/scope.py @@ -4,6 +4,7 @@ from __future__ import print_function from __future__ import unicode_literals import contextlib +import threading from caffe2.proto import caffe2_pb2 @@ -15,38 +16,51 @@ except NameError: basestring = str # The name scope and device scope when creating a new operator. -NAMESCOPE = '' -DEVICESCOPE = None - _NAMESCOPE_SEPARATOR = '/' +_threadlocal_scope = threading.local() + + +def CurrentNameScope(): + global _threadlocal_scope + if not hasattr(_threadlocal_scope, "namescope"): + _threadlocal_scope.namescope = '' + return _threadlocal_scope.namescope + + +def CurrentDeviceScope(): + global _threadlocal_scope + if not hasattr(_threadlocal_scope, "devicescope"): + _threadlocal_scope.devicescope = None + return _threadlocal_scope.devicescope + # NOTE: using NameScope is NOT thread-safe! (TODO t13621185) @contextlib.contextmanager def NameScope(prefix, reset=False): - global NAMESCOPE + global _threadlocal_scope assert isinstance(prefix, basestring), \ "NameScope takes in a string as its argument." - old_scope = NAMESCOPE + old_scope = CurrentNameScope() prefix = prefix + _NAMESCOPE_SEPARATOR if prefix is not '' else '' if reset: - NAMESCOPE = prefix + _threadlocal_scope.namescope = prefix else: - NAMESCOPE = NAMESCOPE + prefix + _threadlocal_scope.namescope = _threadlocal_scope.namescope + prefix yield - assert NAMESCOPE.endswith(prefix), \ + assert _threadlocal_scope.namescope.endswith(prefix), \ "The namescope variable is changed from outside NameScope() calls." - NAMESCOPE = old_scope + _threadlocal_scope.namescope = old_scope @contextlib.contextmanager def DeviceScope(scope): assert isinstance(scope, caffe2_pb2.DeviceOption), \ "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument." - global DEVICESCOPE - old_scope = DEVICESCOPE - DEVICESCOPE = scope + global _threadlocal_scope + old_scope = CurrentDeviceScope() + _threadlocal_scope.devicescope = scope yield - assert DEVICESCOPE == scope, \ + assert _threadlocal_scope.devicescope == scope, \ "The device scope is changed from outside DeviceScope() calls." - DEVICESCOPE = old_scope + _threadlocal_scope.devicescope = old_scope diff --git a/caffe2/python/scope_test.py b/caffe2/python/scope_test.py new file mode 100644 index 000000000000..b1e2ef5aa6a1 --- /dev/null +++ b/caffe2/python/scope_test.py @@ -0,0 +1,83 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import scope, core +from caffe2.proto import caffe2_pb2 + +import unittest +import threading +import time + +SUCCESS_COUNT = 0 + + +def thread_runner(idx, testobj): + global SUCCESS_COUNT + testobj.assertEquals(scope.CurrentNameScope(), "") + testobj.assertEquals(scope.CurrentDeviceScope(), None) + namescope = "namescope_{}".format(idx) + dsc = core.DeviceOption(caffe2_pb2.CUDA, idx) + with scope.DeviceScope(dsc): + with scope.NameScope(namescope): + testobj.assertEquals(scope.CurrentNameScope(), namescope + "/") + testobj.assertEquals(scope.CurrentDeviceScope(), dsc) + + time.sleep(0.01 + idx * 0.01) + testobj.assertEquals(scope.CurrentNameScope(), namescope + "/") + testobj.assertEquals(scope.CurrentDeviceScope(), dsc) + + testobj.assertEquals(scope.CurrentNameScope(), "") + testobj.assertEquals(scope.CurrentDeviceScope(), None) + SUCCESS_COUNT += 1 + + +class TestScope(unittest.TestCase): + + def testNamescopeBasic(self): + self.assertEquals(scope.CurrentNameScope(), "") + + with scope.NameScope("test_scope"): + self.assertEquals(scope.CurrentNameScope(), "test_scope/") + + self.assertEquals(scope.CurrentNameScope(), "") + + def testDevicescopeBasic(self): + self.assertEquals(scope.CurrentDeviceScope(), None) + + dsc = core.DeviceOption(caffe2_pb2.CUDA, 9) + with scope.DeviceScope(dsc): + self.assertEquals(scope.CurrentDeviceScope(), dsc) + + self.assertEquals(scope.CurrentDeviceScope(), None) + + def testMultiThreaded(self): + """ + Test that name/device scope are properly local to the thread + and don't interfere + """ + global SUCCESS_COUNT + self.assertEquals(scope.CurrentNameScope(), "") + self.assertEquals(scope.CurrentDeviceScope(), None) + + threads = [] + for i in range(4): + threads.append(threading.Thread( + target=thread_runner, + args=(i, self), + )) + for t in threads: + t.start() + + with scope.NameScope("master"): + self.assertEquals(scope.CurrentDeviceScope(), None) + self.assertEquals(scope.CurrentNameScope(), "master/") + for t in threads: + t.join() + + self.assertEquals(scope.CurrentNameScope(), "master/") + self.assertEquals(scope.CurrentDeviceScope(), None) + + # Ensure all threads succeeded + self.assertEquals(SUCCESS_COUNT, 4) diff --git a/caffe2/python/session.py b/caffe2/python/session.py new file mode 100644 index 000000000000..8b8a1f616e44 --- /dev/null +++ b/caffe2/python/session.py @@ -0,0 +1,147 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +from caffe2.python import core, workspace +from caffe2.python.task import Task, TaskGroup, WorkspaceType + + +class Session(object): + """ + Allows to run Nets, ExecutionSteps, Plans, Tasks and TaskGroups. + A session can potentially run in multiple nodes concurrently. + + + Example: + from core import Net + from caffe2.python.task import Task, TaskGroup, WorkspaceType + + net = Net('test1') + net.Add([net.Const(1), net.Const(2)]) + + net2 = net.Clone() + step = core.execution_step('step1', [net2]) + + with TaskGroup(WorkspaceType.GLOBAL) as init_tg: + with Node('node1'): + n1setup = net.Net('n1setup') + n1msg = n1setup.Const('Hello from node 1.') + Task(step=n1setup) + + with TaskGroup() as private_tg: + with Node('node1'): + n1 = net.Net('n1') + n1.Print(n1msg, 0) + Task(step=n1) + with Node('node2'): + n2 = net.Net('n2') + n2.Print(n2.Const('Hello from node 2.'), 0) + Task(step=n2) + + session = LocalSession() + session.run(net) + session.run(step) + session.run(init_tg) + session.run(private_tg) + + + Global Workspace: + At the beggining of the session, a global workspace is created and kept + alive for the duration of the session. + + + Private Workspace: + Tasks can be run either directly on the global workspace, or they can + instantiate a private child workspace that is released after each run. + + Blob visibility: + Tasks running in different nodes in parallel will always run under + different workspaces, so it must be assumed that they won't be able to + access each other's blobs. On the other hand, tasks running on the same + node are guaranteed to run on the same workspace within a run. + """ + def __init__(self): + self._open = True + self._runnable_cache = {} + + def is_open(self): + return self._open + + def run(self, runnable): + assert self.is_open(), 'Session is closed.' + if runnable not in self._runnable_cache: + if isinstance(runnable, TaskGroup): + tg = runnable + else: + tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL) + if isinstance(runnable, Task): + tg.add(runnable) + elif isinstance(runnable, core.ExecutionStep): + tg.add(Task(step=runnable)) + else: + step = core.execution_step('runnable', runnable) + tg.add(Task(step=step)) + self._runnable_cache[runnable] = tg + self._run_task_group(self._runnable_cache[runnable]) + + def close(self): + if self.is_open(): + self._do_close() + self._open = False + + def fetch_output(self, output): + raise NotImplementedError() + + def _run_task_group(self, task_group): + raise NotImplementedError() + + def _do_close(self): + pass + + def __enter__(self): + assert self._open, 'Session already closed.' + return self + + def __exit__(self, ex_type, value, traceback): + if ex_type is None: + self.close() + + +class LocalSession(Session): + """ + Session that runs in a single node. + Tasks are all remapped to run in parallel in the 'local' node. + + Currently, LocalSession runs all parallel tasks in the same workspace, + but this behavior may change in the future. Only tasks pointing to the + same logical node are guaranteed to always run in the same workspace. + """ + def __init__(self, ws): + Session.__init__(self) + self._ws = ws + self._plan_caches = {} + + def _run_task_group(self, task_group): + if task_group not in self._plan_caches: + task = task_group.to_task() + plan = core.Plan('task_group_plan') + plan.AddStep(task.get_step()) + self._plan_caches[task_group] = (plan, task) + plan, task = self._plan_caches[task_group] + + # make sure the output blobs belong to the parent workspace + outputs = [] + for name in task.output_names(): + self._ws.create_blob(str(name)) + outputs.append(core.BlobReference(str(name))) + task.set_outputs(outputs, _fetch_func=self._fetch_output) + task_ws = ( + workspace.C.Workspace(self._ws) + if task.workspace_type == WorkspaceType.PRIVATE else self._ws) + with workspace.WorkspaceGuard(task_ws): + task_ws.run(plan) + + def _fetch_output(self, output): + return self._ws.blobs[str(output)].fetch() diff --git a/caffe2/python/session_test.py b/caffe2/python/session_test.py new file mode 100644 index 000000000000..5466bac70c5e --- /dev/null +++ b/caffe2/python/session_test.py @@ -0,0 +1,60 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python.schema import ( + Struct, FetchRecord, NewRecord, FeedRecord, InitEmptyRecord) +from caffe2.python import core, workspace +from caffe2.python.session import LocalSession +from caffe2.python.dataset import Dataset +from caffe2.python.pipeline import pipe +from caffe2.python.task import TaskGroup +from caffe2.python.test_util import TestCase +import numpy as np + + +class TestLocalSession(TestCase): + def test_local_session(self): + init_net = core.Net('init') + src_values = Struct( + ('uid', np.array([1, 2, 6])), + ('value', np.array([1.4, 1.6, 1.7]))) + expected_dst = Struct( + ('uid', np.array([2, 4, 12])), + ('value', np.array([0.0, 0.0, 0.0]))) + + src_blobs = NewRecord(init_net, src_values) + dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema()) + + def proc1(rec): + net = core.Net('proc1') + out = NewRecord(net, rec) + net.Add([rec.uid(), rec.uid()], [out.uid()]) + out.value.set(blob=rec.value()) + return [net], out + + def proc2(rec): + net = core.Net('proc2') + out = NewRecord(net, rec) + out.uid.set(blob=rec.uid()) + net.Sub([rec.value(), rec.value()], [out.value()]) + return [net], out + + src_ds = Dataset(src_blobs) + dst_ds = Dataset(dst_blobs) + + with TaskGroup() as tg: + out1 = pipe(src_ds.reader(), processor=proc1) + out2 = pipe(out1, processor=proc2) + pipe(out2, dst_ds.writer()) + + ws = workspace.C.Workspace() + FeedRecord(src_blobs, src_values, ws) + session = LocalSession(ws) + session.run(init_net) + session.run(tg) + output = FetchRecord(dst_blobs, ws=ws) + + for a, b in zip(output.field_blobs(), expected_dst.field_blobs()): + np.testing.assert_array_equal(a, b) diff --git a/caffe2/python/snapshot.py b/caffe2/python/snapshot.py new file mode 100644 index 000000000000..6f350e83554d --- /dev/null +++ b/caffe2/python/snapshot.py @@ -0,0 +1,263 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import logging +from caffe2.python import core, context +from caffe2.python.task import Node, Task, TaskGroup, TaskOutput, WorkspaceType + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@context.define_context() +class Job(object): + """ + A Job defines two TaskGroups: the `init_group` and the `epoch_group`, which + will be run by a JobRunner. + + The `init_group` will be run only once at startup. Its role is to + initialize globally persistent blobs such as model weights, accumulators + and data file lists. + + The `epoch_group` will be run in a loop after init_group. The loop will + exit when any of the stop signals added with `add_stop_signal` is True + at the end of an epoch. + + Jobs are context-driven, so that Tasks can be added to the active Job + without having to explicitly pass the job object around. + + Example of usage: + + def build_reader(partitions): + with Job.current().init_group: + reader = HiveReader(init_reader, ..., partitions) + Task(step=init_reader) + with Job.current().epoch_group: + limited_reader = ReaderWithLimit(reader, num_iter=10000) + data_queue = pipe(limited_reader, num_threads=8) + Job.current().add_stop_signal(limited_reader.data_finished()) + return data_queue + + def build_hogwild_trainer(reader, model): + with Job.current().init_group: + Task(step=model.param_init_net) + with Job.current().epoch_group: + pipe(reader, processor=model, num_threads=8) + + with Job() as job: + reader = build_reader(partitions) + model = build_model(params) + build_hogwild_trainer(reader, model) + """ + def __init__(self): + self.init_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL) + self.epoch_group = TaskGroup() + self.stop_signals = [] + + def __enter__(self): + self.epoch_group.__enter__() + return self + + def __exit__(self, *args): + self.epoch_group.__exit__() + + def add_stop_signal(self, output): + if isinstance(output, core.BlobReference): + t = Task(outputs=[output], group=self.epoch_group) + output = t.outputs()[0] + assert isinstance(output, TaskOutput) + self.stop_signals.append(output) + + +class SnapshotManager(object): + """ + Controls saving and loading of workspaces on every epoch boundary of a job. + If a SnapshotManager instance is passed to JobRunner, then JobRunner will + call `init`, `read` and `save` at different moments in between epoch runs. + """ + def __init__(self, db, db_type): + self._db = db + self._db_type = db_type + # make sure these blobs are the first in the snapshot file. + self._net = core.Net('!!snapshot_mngr') + self._blob_names = self._net.AddExternalInput('blob_names') + self._names_output = None + + def init(self, nodes=None, retrieve_from_epoch=None): + """ + Build a Task that will be run once after the job's `init_group` is run. + This task will determine which blobs need to be snapshoted. + If retrieve_from_epoch is not None, then the snapshot metadata is + retrieved from a previously saved snapshot. + """ + assert nodes is None or len(nodes) == 1, ( + 'SnapshotManager only supports single node.') + net = core.Net('get_blob_list') + if retrieve_from_epoch is None: + net.GetAllBlobNames( + [], + self._blob_names, + include_shared=False) + else: + net.Load( + [], self._blob_names, + db=self._dbname(retrieve_from_epoch), + db_type=self._db_type, + absolute_path=True) + task = Task(step=net, outputs=[self._blob_names]) + self._names_output = task.outputs()[0] + return task + + def blob_list(self): + assert self._names_output + return self._names_output.fetch().tolist() + + def _dbname(self, epoch): + return '%s.%06d' % (self._db, epoch) + + def load(self, epoch): + """ + Build a Task that will be run by JobRunner when the job is to be + resumed from a given epoch. This task will run a Load op that will + load and deserialize all relevant blobs from a persistent storage. + """ + net = core.Net('get_blob_list') + net.Load( + [], + self.blob_list(), + db=self._dbname(epoch), + db_type=self._db_type, + absolute_path=True) + return Task(step=net) + + def save(self, epoch): + """ + Build a Task that is run once after `init_group` and after each + epoch is run. This will execute a Save ops to serialize and persist + blobs present in the global workspaace. + """ + net = core.Net('snapshot_save') + net.Save( + self.blob_list(), [], db=self._dbname(epoch), + db_type=self._db_type, absolute_path=True) + return Task(step=net) + + +class MultiNodeSnapshotManager(object): + """ + Coordinates snapshoting and checkpointing across multiple nodes. + Each of `init`, `load` and `save` will build TaskGroups which will + trigger snapshotting on each of the nodes involved in a distributed job. + """ + def __init__(self, db_prefix, db_type, node_manager_class=SnapshotManager): + self._node_manager_class = node_manager_class + self._node_managers = None + self._db_prefix = db_prefix + self._db_type = db_type + + def _task_group(self, func, *args, **kw): + assert self._node_managers is not None, 'init must be called first.' + with TaskGroup(WorkspaceType.GLOBAL) as task_group: + for node, manager in self._node_managers: + with Node(node): + func(manager, *args, **kw) + return task_group + + def init(self, nodes, retrieve_from_epoch=None): + if self._node_managers is not None: + assert [node for node, _ in self._node_managers] == nodes + return + self._node_managers = [] + for node in nodes: + with Node(node): + manager = self._node_manager_class( + db=os.path.join(self._db_prefix, node), + db_type=self._db_type) + self._node_managers.append((node, manager)) + return self._task_group( + self._node_manager_class.init, + nodes=[node], + retrieve_from_epoch=retrieve_from_epoch) + + def load(self, epoch): + return self._task_group(self._node_manager_class.load, epoch) + + def save(self, epoch): + return self._task_group(self._node_manager_class.save, epoch) + + +class JobRunner(object): + """ + Implement the runtime logic for jobs with checkpointing at the level of + epoch. Can be used to run either single-host or distributed jobs. Job + runner is a callable to be called once from the client, passing a Session + as argument. This call will block until the Job execution is complete. + + If a snapshot_manager is passed, snapshots will be taken after + initialization and after each epoch execution. If, in addition, + `resume_from_epoch` is an epoch number, the corresponding snapshot will + be loaded and job execution will continue from the given epoch. In + this case, the job's init_group will not be run. + + Refer to snapshot_test.py for an example. + """ + def __init__(self, job, snapshot_manager=None, resume_from_epoch=None): + self.resume_from_epoch = resume_from_epoch + self.snapshot = snapshot_manager + self.job = job + + def __call__(self, client): + from_scratch = self.resume_from_epoch is None + if from_scratch: + client.run(self.job.init_group) + + if self.snapshot: + logger.info('Preparing snapshot ...') + client.run(self.snapshot.init( + self.job.init_group.used_nodes(), + retrieve_from_epoch=self.resume_from_epoch)) + if from_scratch: + logger.info('Saving first snapshot ...') + client.run(self.snapshot.save(0)) + logger.info('First snapshot saved.') + else: + logger.info('Loading snapshot for epoch {} ...'.format( + self.resume_from_epoch)) + client.run(self.snapshot.load(self.resume_from_epoch)) + logger.info('Snapshot loaded.') + + epoch = 1 if from_scratch else self.resume_from_epoch + 1 + while True: + logger.info('Starting epoch %d.' % epoch) + client.run(self.job.epoch_group) + logger.info('Ran epoch %d.' % epoch) + stop_signals = [o.fetch() for o in self.job.stop_signals] + + if self.snapshot: + logger.info('Saving snapshot ...') + client.run(self.snapshot.save(epoch)) + logger.info('Snapshot saved.') + + if any(stop_signals): + logger.info('Stopping.') + break + epoch += 1 + return epoch + + +def epoch_limiter(num_epochs): + """ + Creates a task that will output True when a given + number of epochs has finished. + """ + with Job.current().init_group: + init_net = core.Net('epoch_counter_init') + counter = init_net.CreateCounter([], init_count=num_epochs - 1) + Task(step=init_net) + epoch_net = core.Net('epoch_countdown') + finished = epoch_net.CountDown(counter) + output = Task(step=epoch_net, outputs=finished).outputs()[0] + Job.current().add_stop_signal(output) diff --git a/caffe2/python/snapshot_test.py b/caffe2/python/snapshot_test.py new file mode 100644 index 000000000000..5f305f67797b --- /dev/null +++ b/caffe2/python/snapshot_test.py @@ -0,0 +1,95 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python.schema import Struct, ConstRecord +from caffe2.python import core, workspace +from caffe2.python.session import LocalSession +from caffe2.python.dataset import Dataset +from caffe2.python.pipeline import pipe +from caffe2.python.snapshot import ( + SnapshotManager, MultiNodeSnapshotManager, Job, JobRunner) +from caffe2.python.task import Task, Node +from caffe2.python.test_util import TestCase +from caffe2.python.dataio import ReaderWithLimit +import tempfile +import numpy as np +import shutil + + +def build_job(): + with Node('reader'): + with Job() as job: + with job.init_group: + init_net = core.Net('init_net') + data_arr = Struct(('val', np.array(range(10)))) + data = ConstRecord(init_net, data_arr) + ds = Dataset(data) + full_reader = ds.reader(init_net) + total = init_net.Const([100]) + Task(step=init_net) + + def inc_total(rec): + net = core.Net('inc_total') + net.Add([total, rec.val()], [total]) + return [net] + + epoch_reader = ReaderWithLimit(full_reader, num_iter=3) + pipe(epoch_reader, processor=inc_total) + job.add_stop_signal(epoch_reader.data_finished()) + + total_fetcher = Task(step=core.Net('empty'), outputs=[total]) + return job, total_fetcher + + +EXPECTED_TOTALS = [103, 115, 136, 145] + + +class TestSnapshot(TestCase): + def run_with(self, builder): + job, output_fetcher = build_job() + + def fetch_total(session): + session.run(output_fetcher) + return output_fetcher.outputs()[0].fetch() + + session, snapshot = builder() + num_epochs = JobRunner(job, snapshot)(session) + self.assertEquals(num_epochs, len(EXPECTED_TOTALS)) + self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1]) + + for initial_epoch in range(1, num_epochs + 1): + session, snapshot = builder() + JobRunner(job, snapshot, resume_from_epoch=initial_epoch)(session) + self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1]) + + for epoch in range(1, num_epochs + 1): + session.run(snapshot.load(epoch)) + self.assertEquals(fetch_total(session), EXPECTED_TOTALS[epoch - 1]) + + def test_single_snapshot(self): + # test single node + with tempfile.NamedTemporaryFile() as tmp: + + def builder(): + ws = workspace.C.Workspace() + session = LocalSession(ws) + snapshot = SnapshotManager(tmp.name, 'minidb') + return session, snapshot + + self.run_with(builder) + + # test multi-node + try: + tmpdir = tempfile.mkdtemp() + + def builder(): + ws = workspace.C.Workspace() + session = LocalSession(ws) + snapshot = MultiNodeSnapshotManager(tmpdir, 'minidb') + return session, snapshot + + self.run_with(builder) + finally: + shutil.rmtree(tmpdir) diff --git a/caffe2/python/task.py b/caffe2/python/task.py new file mode 100644 index 000000000000..78f6a79c3122 --- /dev/null +++ b/caffe2/python/task.py @@ -0,0 +1,482 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, context +from caffe2.python.schema import Field, from_blob_list +from collections import defaultdict + + +@context.define_context(allow_default=True) +class Cluster(object): + """ + Context that keeps track of all the node names used. + Users shouldn't have to use them directly, since a Cluster is automatically + generated at the first usage of 'Node'. + """ + + def __init__(self): + # list instead of set to keep order + self._nodes = [] + + def add_node(self, node): + if str(node) not in self._nodes: + self._nodes.append(str(node)) + + def nodes(self): + """ + Returns the list of unique node names used within this context. + """ + return self._nodes + + +@context.define_context(allow_default=True) +class Node(object): + """ + A Node context is used to indicate that all Tasks instantiated within will + run on the given node name. (Only the name of the node actually counts.) + Example: + + with TaskGroup() as tg: + with Node('node1'): + s1 = execution_step(...) + Task(step=s1) + with Node('node2'): + s2 = execution_step(...) + with Node('node1'): + s3 = execution_step(...) + + In this example, all three execution steps will run in parallel. + Moreover, s1 and s3 will run on the same node, and can see each + others blobs. + """ + + def __init__(self, node='local'): + self._name = str(node) + Cluster.current().add_node(self) + + def __str__(self): + return self._name + + +class WorkspaceType(object): + """ + Determines whether tasks of a TaskGroup will run directly at the global + workspace, which is kept alive across runs, or whether a new child + workspace will be created for the run and destroyed afterwards. + """ + PRIVATE = 'private' + GLOBAL = 'global' + + +def get_setup_nets(key, steps, target): + init_net = core.Net(key + '/init') + exit_net = core.Net(key + '/exit') + init_nets = [] + exit_nets = [] + objs = set() + for step in steps: + if step is not None: + objs |= set(step.get_all_attributes(key)) + for obj in objs: + # these are needed in order to allow nesting of TaskGroup, which + # is a feature not yet implemented. + if hasattr(obj, '_setup_used') and obj._setup_used: + continue + if hasattr(obj, '_setup_target') and obj._setup_target != target: + continue + if hasattr(obj, 'setup'): + nets = obj.setup(init_net) + if isinstance(nets, (list, tuple)): + init_nets += nets + elif isinstance(nets, core.Net): + init_nets.append(nets) + elif nets is not None: + raise TypeError('Unsupported type for setup: %s' % type(nets)) + obj._setup_used = True + if hasattr(obj, 'exit'): + nets = obj.exit(exit_net) + if isinstance(nets, (list, tuple)): + exit_nets += nets + elif isinstance(nets, core.Net): + exit_nets.append(nets) + elif nets is not None: + raise TypeError('Unsupported type for setup: %s' % type(nets)) + obj._setup_used = True + + if len(init_net.Proto().op) > 0: + init_nets.insert(0, init_net) + if len(exit_net.Proto().op) > 0: + exit_nets.insert(0, exit_net) + return init_nets, exit_nets + + +@context.define_context(allow_default=False) +class TaskGroup(object): + """ + Context that gathers tasks which will run concurrently, potentially on + multiple nodes. All tasks in the same node will share the same workspace + and thus can share blobs, while tasks running in different nodes won't + be able to directly share data. + + All tasks of the task group will start concurrently, and the task group + will finish execution when the last task of the group finishes. + + Example: + # supose that s1 ... s5 are execution steps or nets. + with TaskGroup() as tg: + # these tasks go to default node 'local' + Task(step=s1) + Task(step=s2) + + with Node('n2'): + Task(step=s3) + with Node('n1'): + Task(step=s4) + with Node('n2'): + Task(step=s5) + + # this will run all steps in parallel. + # s1 and s2 will run at default node 'local' + # s3 and s5 will run at node 'n2' + # s4 will run at node 'n1' + session.run(tg) + """ + LOCAL_SETUP = 'local_setup' + + def __init__(self, workspace_type=None): + self._plan_cache = None + self._tasks = [] + self._already_used = False + self._prev_active = None + self._tasks_to_add = [] + self._report_nets = {} + self._workspace_type = workspace_type + self._tasks_by_node = None + + def add(self, task): + assert not self._already_used, ( + 'Cannot add Task to an already used TaskGroup.') + assert ( + self._workspace_type is None or + task._workspace_type is None or + self._workspace_type == task._workspace_type) + if task._workspace_type is None: + task._workspace_type = ( + self._workspace_type or WorkspaceType.PRIVATE) + if self._workspace_type is None: + self._workspace_type = task._workspace_type + task._notify_used() + self._tasks.append(task) + + def tasks(self): + for task in self._tasks_to_add: + self.add(task) + self._tasks_to_add = [] + self._already_used = True + return self._tasks + + def num_registered_tasks(self): + return len(self._tasks_to_add) + len(self._tasks) + + def used_nodes(self): + # use list to keep order + used = [] + for task in self.tasks(): + if task.node not in used: + used.append(task.node) + return used + + def report_net(self, net=None, node=None, report_interval=5): + """ + Get or set the `report_net`, which is a net that runs repeatedly every + `report_interval` seconds for the duration of the TaskGroup execution + on each of the nodes. Each node has it's own report net. + + Example: + + with TaskGroup() as tg: + for i in range(0, 2): + with Node('trainer:%d' % i): + report_net = tg.report_net() + report_net.LogInfo('5s passed in trainer %d' % i) + + This will print '5s passed in trainer {}' every 5s on each one of the + trainer nodes. + """ + node = str(Node.current(node)) + assert net is None or node not in self._report_nets + if node not in self._report_nets: + self._report_nets[node] = ( + net if net else core.Net('%s/reporter' % node), + report_interval) + return self._report_nets[node][0] + + def tasks_by_node(self, node_remap=None): + # tasks_by_node can't be called twice because the setup won't + # work properly a second time. + node_map = {} + for task in self.tasks(): + node_map[task.node] =\ + node_remap(task.node) if node_remap else task.node + if self._tasks_by_node is not None: + tasks_by_node, prev_node_map = self._tasks_by_node + assert prev_node_map == node_map, ( + 'Cannot call tasks_by_node multiple times.') + return tasks_by_node + + tasks_by_node = defaultdict(list) + for task in self.tasks(): + tasks_by_node[node_map[task.node]].append(task) + grouped_by_node = TaskGroup() + for node, tasks in tasks_by_node.items(): + node_inits, node_exits = get_setup_nets( + TaskGroup.LOCAL_SETUP, [t.get_step() for t in tasks], self) + # shortcut for single task with no queue + steps = [] + outputs = [] + workspace_type = tasks[0].workspace_type() + for task in tasks: + step = task.get_step() + if step is not None: + steps.append(step) + outputs += task.outputs() + assert workspace_type == task.workspace_type(), ( + 'All tasks for a given node need same workspace type.') + if len(steps) == 0: + steps.append(core.execution_step('empty', [])) + if len(steps) == 1: + step = steps[0] + else: + step = core.execution_step( + '%s:body' % node, steps, concurrent_substeps=True) + if node in self._report_nets: + net, interval = self._report_nets[node] + step.SetReportNet(net, interval) + if len(node_inits) > 0 or len(node_exits) > 0: + steps = [] + if len(node_inits) > 0: + steps.append( + core.execution_step('%s:init' % node, node_inits)) + steps.append(step) + if len(node_exits) > 0: + steps.append( + core.execution_step('%s:exit' % node, node_exits)) + step = core.execution_step(node, steps) + Task( + node=node, step=step, outputs=outputs, + group=grouped_by_node, workspace_type=workspace_type) + self._tasks_by_node = (grouped_by_node, node_map) + return grouped_by_node + + def to_task(self, node='local'): + return self.tasks_by_node(lambda x: 'local').tasks()[0] + + +class TaskOutput(object): + """ + Represents the output of a task. An output can be a blob, + a list of blob, or a record. + """ + + def __init__(self, names): + self._schema = None + self._is_scalar = False + if isinstance(names, Field): + self._schema = names + names = self._schema.field_blobs() + self._is_scalar = type(names) not in (tuple, list) + if self._is_scalar: + names = [names] + self.names = names + self._values = None + + def set(self, values, _fetch_func=None): + assert len(values) == len(self.names) + self._values = values + self._fetch_func = _fetch_func + + def get(self): + assert self._values is not None, 'Output value not set yet.' + if self._is_scalar: + return self._values[0] + elif self._schema: + return from_blob_list(self._schema, self._values) + else: + return self._values + + def fetch(self): + assert self._fetch_func is not None, ( + 'Cannot fetch value for this output.') + fetched_vals = [self._fetch_func(v) for v in self._values] + if self._is_scalar: + return fetched_vals[0] + elif self._schema: + return from_blob_list(self._schema, fetched_vals) + else: + return fetched_vals + + +def final_output(blob_or_record): + """ + Create a dummy task that returns the given blob or record + to the client. This will return the value of the blob or record when + the last task of the TaskGroup for a given node finishes. + """ + return Task(outputs=blob_or_record).outputs()[0] + + +class Task(object): + """ + A Task is composed of an execution step and zero or more outputs. + Tasks are executed in the context of a TaskGroup, which, in turn, can + be run by a Session. + + Task outputs are fetched by the session at the end of the run. + """ + + TASK_SETUP = 'task_setup' + + def __init__( + self, step=None, outputs=None, + workspace_type=None, group=None, node=None): + """ + Instantiate a Task and add it to the current TaskGroup and Node. + """ + # register this node name with active context + self.node = str(Node.current(None if node is None else Node(node))) + group = TaskGroup.current(group, required=False) + if group is not None: + group._tasks_to_add.append(self) + + self._already_used = False + self._step = None + self._step_with_setup = None + self._outputs = [] + if step is not None: + self.set_step(step) + if outputs is not None: + self.add_outputs(outputs) + + self._pipeline = None + self._is_pipeline_context = False + self._workspace_type = workspace_type + + def workspace_type(self): + return self._workspace_type + + def _assert_not_used(self): + assert not self._already_used, ( + 'Cannot modify task since it is already been used.') + + def add_output(self, output): + self._assert_not_used() + self._outputs.append( + output if isinstance(output, TaskOutput) else TaskOutput(output)) + + def add_outputs(self, outputs): + self._assert_not_used() + if type(outputs) not in (list, tuple): + outputs = [outputs] + for output in outputs: + self.add_output(output) + + def set_step(self, step): + self._assert_not_used() + self._step = core.to_execution_step(step) + + def get_step(self): + if self._step is not None and self._step_with_setup is None: + init_nets, exit_nets = get_setup_nets( + Task.TASK_SETUP, [self._step], self) + if len(self._outputs) == 0: + output_net = core.Net("output_net") + self.add_output(output_net.ConstantFill( + [], 1, dtype=core.DataType.INT32, value=0)) + exit_nets.append(output_net) + self._step_with_setup = core.execution_step( + 'task', + [ + core.execution_step('task_init', init_nets), + self._step, + core.execution_step('task_exit', exit_nets), + ] + ) + return self._step_with_setup + + def outputs(self): + return self._outputs + + def output_names(self): + """ + Retrive the output names. + TODO(azzolini): make this schema-based. + """ + names = [] + for o in self._outputs: + names += o.names + return names + + def set_outputs(self, values, _fetch_func): + """ + Set output values. + TODO(azzolini): make this schema-based. + """ + offset = 0 + for o in self._outputs: + num = len(o.names) + o.set(values[offset:offset + num], _fetch_func) + offset += num + assert offset == len(values), 'Wrong number of output values.' + + def resolved_outputs(self): + return [output.get() for output in self._outputs] + + def _notify_used(self): + self.get_step() + self._already_used = True + + +class SetupNets(object): + """ + Allow to register a list of nets to be run at initialization + and finalization of Tasks or TaskGroups. + For example, let's say you have the following: + + init_net = core.Net('init') + my_val = init_net.ConstantFill([], 'my_val', value=0) + + net = core.Net('counter') + net.Add([my_val, net.Const(1),], [my_val]) + + with TaskGroup() as task_group: + with Node('trainer'): + my_task = Task(step=[net]) + + In order to have `init_net` run once before `net` runs for the + first time, you can do one of the following: + + net.add_object(Task.TASK_SETUP, SetupNets([init_net])) + + or + + net.add_object(TaskGroup.LOCAL_SETUP, SetupNets([init_net])) + + - With Task.TASK_SETUP, init_net will run once at my_task startup. + - With TaskGroup.LOCAL_SETUP, init_net will run once on node 'trainer', + before any task of the task group is run on that node. + + The same SetupNets object can be added to multiple nets. It will only + run once per Task/TaskGroup run. + """ + + def __init__(self, init_nets=None, exit_nets=None): + self.init_nets = init_nets + self.exit_nets = exit_nets + + def setup(self, init_net): + return self.init_nets + + def exit(self, exit_net): + return self.exit_nets diff --git a/caffe2/python/timeout_guard.py b/caffe2/python/timeout_guard.py new file mode 100644 index 000000000000..6ffef6da72b0 --- /dev/null +++ b/caffe2/python/timeout_guard.py @@ -0,0 +1,56 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import contextlib +import threading +import os +import time +import logging + + +''' +Sometimes CUDA devices can get stuck, 'deadlock'. In this case it is often +better just the kill the process automatically. Use this guard to set a +maximum timespan for a python call, such as RunNet(). If it does not complete +in time, process is killed. + +Example usage: + with timeout_guard.CompleteInTimeOrDie(10.0): + core.RunNet(...) +''' + + +class WatcherThread(threading.Thread): + + def __init__(self, timeout_secs): + threading.Thread.__init__(self) + self.timeout_secs = timeout_secs + self.completed = False + self.condition = threading.Condition() + self.daemon = True + + def run(self): + started = time.time() + self.condition.acquire() + while time.time() - started < self.timeout_secs and not self.completed: + self.condition.wait(self.timeout_secs - (time.time() - started)) + self.condition.release() + if not self.completed: + log = logging.getLogger("timeout_guard") + log.error("Call did not finish in time. Timeout:{}s".format( + self.timeout_secs + )) + os._exit(1) + + +@contextlib.contextmanager +def CompleteInTimeOrDie(timeout_secs): + watcher = WatcherThread(timeout_secs) + watcher.start() + yield + watcher.completed = True + watcher.condition.acquire() + watcher.condition.notify() + watcher.condition.release() diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py index 1ce3d625d8f8..d967c08f0e19 100644 --- a/caffe2/python/workspace.py +++ b/caffe2/python/workspace.py @@ -25,6 +25,7 @@ RootFolder = C.root_folder Workspaces = C.workspaces BenchmarkNet = C.benchmark_net +is_asan = C.is_asan has_gpu_support = C.has_gpu_support if has_gpu_support: NumCudaDevices = C.num_cuda_devices @@ -201,8 +202,8 @@ def FeedBlob(name, arr, device_option=None): name = StringifyBlobName(name) if device_option is not None: return C.feed_blob(name, arr, StringfyProto(device_option)) - elif scope.DEVICESCOPE is not None: - return C.feed_blob(name, arr, StringfyProto(scope.DEVICESCOPE)) + elif scope.CurrentDeviceScope() is not None: + return C.feed_blob(name, arr, StringfyProto(scope.CurrentDeviceScope())) else: return C.feed_blob(name, arr) @@ -231,7 +232,7 @@ def FetchBlob(name): def GetNameScope(): """Return the current namescope string. To be used to fetch blobs""" - return scope.NAMESCOPE + return scope.CurrentNameScope() class _BlobDict(object): diff --git a/caffe2/queue/queue_ops.h b/caffe2/queue/queue_ops.h index 7f91919d7f09..7cadbb1e3e58 100644 --- a/caffe2/queue/queue_ops.h +++ b/caffe2/queue/queue_ops.h @@ -79,7 +79,6 @@ class CloseBlobsQueueOp final : public Operator { OperatorBase::Inputs()[0]->template Get>(); CHECK(queue); queue->close(); - queue.reset(); return true; } @@ -98,8 +97,8 @@ class SafeEnqueueBlobsOp final : public Operator { auto size = queue->getNumBlobs(); CAFFE_ENFORCE( OutputSize() == size + 1, - "Expected " + std::to_string(size + 1) + ", " + " got: " + - std::to_string(size)); + "Expected " + caffe2::to_string(size + 1) + ", " + " got: " + + caffe2::to_string(size)); bool status = queue->blockingWrite(this->Outputs()); Output(size)->Resize(); *Output(size)->template mutable_data() = !status; @@ -120,8 +119,8 @@ class SafeDequeueBlobsOp final : public Operator { auto size = queue->getNumBlobs(); CAFFE_ENFORCE( OutputSize() == size + 1, - "Expected " + std::to_string(size + 1) + ", " + " got: " + - std::to_string(size)); + "Expected " + caffe2::to_string(size + 1) + ", " + " got: " + + caffe2::to_string(size)); bool status = queue->blockingRead(this->Outputs()); Output(size)->Resize(); *Output(size)->template mutable_data() = !status; diff --git a/caffe2/sgd/adam_op.h b/caffe2/sgd/adam_op.h index 1da26dc00b0d..83deeb8731a5 100644 --- a/caffe2/sgd/adam_op.h +++ b/caffe2/sgd/adam_op.h @@ -23,7 +23,7 @@ void adam_update( float gi = g[i]; float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); - ng[i] = lr[0] * correction * mi / (sqrt(vi) + eps_hat); + ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat); } } @@ -47,7 +47,7 @@ void adam_compute( float gi = g[i]; float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); - float ng = lr[0] * correction * mi / (sqrt(vi) + eps_hat); + float ng = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat); nw[i] = w[i] + ng; } } @@ -154,7 +154,7 @@ class SparseAdamOp final : public Operator { float vi = moment2Out[idx] = moment2In[idx] * beta2_ + gi * gi * (1 - beta2_); paramOut[idx] = - paramIn[idx] + lr[0] * correction * mi / (sqrt(vi) + epsilon_); + paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_); } else { auto offsetI = i * block_size; diff --git a/caffe2/sgd/adam_op_gpu.cu b/caffe2/sgd/adam_op_gpu.cu index 7ba0856e354b..05ff1232f731 100644 --- a/caffe2/sgd/adam_op_gpu.cu +++ b/caffe2/sgd/adam_op_gpu.cu @@ -21,7 +21,7 @@ __global__ void AdamUpdate( float gi = g[i]; float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1); float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2); - ng[i] = lr[0] * correction * mi / (sqrt(vi) + eps_hat); + ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat); } } diff --git a/caffe2/utils/cpu_neon.h b/caffe2/utils/cpu_neon.h new file mode 100644 index 000000000000..5e66e4ecd969 --- /dev/null +++ b/caffe2/utils/cpu_neon.h @@ -0,0 +1,61 @@ +#ifndef CAFFE2_UTILS_CPU_NEON_H_ +#define CAFFE2_UTILS_CPU_NEON_H_ + +// Provides a variety of ARM NEON-specific utility functions +#ifdef __ARM_NEON__ +#include + +namespace caffe2 { + +inline int divUp(int a, int b) { + return (a + b - 1) / b; +} + +inline int roundUp(int a, int b) { + return divUp(a, b) * b; +} + +template +inline bool isPointerAligned(T* p, size_t align) { + return (reinterpret_cast(p) % align == 0); +} + +inline float32x4_t vert_sum_f32(float32x4_t v0, + float32x4_t v1, + float32x4_t v2, + float32x4_t v3) { + v0 = vaddq_f32(v0, v1); + v2 = vaddq_f32(v2, v3); + return vaddq_f32(v0, v2); +} + +inline float horizontal_sum_f32(float32x4_t v0, + float32x4_t v1, + float32x4_t v2, + float32x4_t v3) { + v0 = vert_sum_f32(v0, v1, v2, v3); + float32x2_t v = vadd_f32(vget_high_f32(v0), vget_low_f32(v0)); + return vget_lane_f32(vpadd_f32(v, v), 0); +} + +// Load/store functions that assume alignment + +inline float32x4_t vld1q_f32_aligned(const float* p) { + return vld1q_f32((const float*) + __builtin_assume_aligned(p, sizeof(float32x4_t))); +} + +inline void vst1q_f32_aligned(float* p, float32x4_t v) { + vst1q_f32((float*) __builtin_assume_aligned(p, sizeof(float32x4_t)), v); +} + +inline void vst4_u8_aligned(uint8_t* p, uint8x8x4_t v) { + vst4_u8((uint8_t*) + __builtin_assume_aligned(p, sizeof(uint8x8x4_t)), v); +} + +} // namespace caffe2 + +#endif // __ARM_NEON__ + +#endif // CAFFE2_UTILS_CPU_NEON_H_ diff --git a/caffe2/utils/fixed_divisor.h b/caffe2/utils/fixed_divisor.h new file mode 100644 index 000000000000..a6cc67040205 --- /dev/null +++ b/caffe2/utils/fixed_divisor.h @@ -0,0 +1,146 @@ +#ifndef CAFFE2_UTILS_FIXED_DIVISOR_H_ +#define CAFFE2_UTILS_FIXED_DIVISOR_H_ + +#ifdef __ARM_NEON__ +#include +#endif +#include + +namespace caffe2 { + +namespace detail { + +inline uint32_t mulHi(uint32_t x, uint32_t y) { + uint64_t v = (uint64_t) x * (uint64_t) y; + return (uint32_t) (v >> 32); +} + +} + +// Utility class for quickly calculating quotients and remainders for +// a known integer divisor +template +class FixedDivisor { +}; + +template <> +class FixedDivisor { + public: + typedef int Type; + + FixedDivisor(int d) : d_(d) { + calcSignedMagic(); + } + + /// Calculates `q = n / d`. + inline int div(int n) const { + return (int) (detail::mulHi(magic_, n) >> shift_); + } + + /// Calculates `r = n % d`. + inline int mod(int n) const { + return n - d_ * div(n); + } + + /// Calculates `q = n / d` and `r = n % d` together. + inline void divMod(int n, int& q, int& r) const { + const int quotient = div(n); + q = quotient; + r = n - d_ * quotient; + } + +#ifdef __ARM_NEON__ + inline void divModVector(int32x4_t n, + int32x4_t& q, + int32x4_t& r) const { + int32x2_t loQ; + int32x2_t loR; + divModVector(vget_low_s32(n), loQ, loR); + + int32x2_t hiQ; + int32x2_t hiR; + divModVector(vget_high_s32(n), hiQ, hiR); + + q = vcombine_s32(loQ, hiQ); + r = vcombine_s32(loR, hiR); + } + + inline void divModVector(int32x2_t n, + int32x2_t& q, + int32x2_t& r) const { + q = divVector(n); + + // r = n - d * q + r = vsub_s32(n, vmul_s32(vdup_n_s32(d_), q)); + } + + // Calculates `q1 = v1 / d, q2 = v2 / d` using NEON + inline int32x2_t divVector(int32x2_t v) const { + uint32x2_t vUnsigned = vreinterpret_u32_s32(v); + + uint32x2_t resultUnsigned = + vmovn_u64( + vshlq_u64( + vmull_u32(vUnsigned, vdup_n_u32(magic_)), + vdupq_n_s64(-32 - shift_))); + + return vreinterpret_s32_u32(resultUnsigned); + } +#endif + + private: + /** + Calculates magic multiplicative value and shift amount for + calculating `q = n / d` for signed 32-bit integers. + Implementation taken from Hacker's Delight section 10. + `d` cannot be in [-1, 1]. + */ + void calcSignedMagic() { + const unsigned int two31 = 0x80000000; + + unsigned int ad = std::abs(d_); + unsigned int t = two31 + ((unsigned int) d_ >> 31); + unsigned int anc = t - 1 - t % ad; // Absolute value of nc. + unsigned int p = 31; // Init. p. + unsigned int q1 = two31 / anc; // Init. q1 = 2**p/|nc|. + unsigned int r1 = two31 - q1 * anc; // Init. r1 = rem(2**p, |nc|). + unsigned int q2 = two31 / ad; // Init. q2 = 2**p/|d|. + unsigned int r2 = two31 - q2 * ad; // Init. r2 = rem(2**p, |d|). + unsigned int delta = 0; + + do { + p = p + 1; + q1 = 2 * q1; // Update q1 = 2**p/|nc|. + r1 = 2 * r1; // Update r1 = rem(2**p, |nc|). + + if (r1 >= anc) { // (Must be an unsigned + q1 = q1 + 1; // comparison here). + r1 = r1 - anc; + } + + q2 = 2 * q2; // Update q2 = 2**p/|d|. + r2 = 2 * r2; // Update r2 = rem(2**p, |d|). + + if (r2 >= ad) { // (Must be an unsigned + q2 = q2 + 1; // comparison here). + r2 = r2 - ad; + } + + delta = ad - r2; + } while (q1 < delta || (q1 == delta && r1 == 0)); + + magic_ = q2 + 1; + if (d_ < 0) { + magic_ = -magic_; + } + shift_ = p - 32; + } + + int d_; + int magic_; + int shift_; +}; + +} // namespace caffe2 + +#endif // CAFFE2_UTILS_FIXED_DIVISOR_H_ diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h index a1d0ec88c06f..497ee944f6e2 100644 --- a/caffe2/utils/math.h +++ b/caffe2/utils/math.h @@ -239,6 +239,16 @@ void Col2im( T* data_im, Context* context); +// Applies a per-channel bias value to each channel of the input +// image. image_size is H * W +template +void BiasCHW( + const T* bias, + const int bias_channels, + const int image_size, + T* image, + Context* context); + template void CopyMatrix(const size_t item_size, const int M, const int N, const void* A, const int lda, void* B, const int ldb, Context* context); @@ -246,6 +256,18 @@ void CopyMatrix(const size_t item_size, const int M, const int N, const void* A, uint32_t randomNumberSeed(); +// Function uses casting from int to unsigned to compare if value of +// parameter a is greater or equal to zero and lower than value of +// parameter b. The b parameter is of type signed and is always +// positive, +// therefore its value is always lower than 0x800... where casting +// negative value of a parameter converts it to value higher than +// 0x800... +// The casting allows to use one condition instead of two. +inline bool is_a_ge_zero_and_a_lt_b(int a, int b) { + return static_cast(a) < static_cast(b); +} + } // namespace math } // namespace caffe2 diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc index 009d6a394f0d..855eda10c3e4 100644 --- a/caffe2/utils/math_cpu.cc +++ b/caffe2/utils/math_cpu.cc @@ -22,6 +22,7 @@ #endif // CAFFE2_USE_MKL #include "caffe2/utils/math.h" +#include "caffe2/utils/cpu_neon.h" #include "caffe2/core/context.h" #include "Eigen/Core" #include "Eigen/Dense" @@ -101,10 +102,79 @@ void Gemm( } } +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const int lda, + const float* B, + const float beta, + const int ldb, + float* C, + const int ldc, + CPUContext*) { + using OuterStride = Eigen::OuterStride; + using StridedMap = Eigen::Map; + using ConstStridedMap = Eigen::Map; + auto C_mat = StridedMap(C, N, M, OuterStride(ldc)); + if (beta == 0) { + C_mat.setZero(); + } else { + C_mat *= beta; + } + switch (TransA) { + case CblasNoTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += + alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) * + ConstStridedMap(A, K, M, OuterStride(lda))); + return; + case CblasTrans: + C_mat.noalias() += + alpha * (ConstStridedMap(B, K, N, OuterStride(ldb)).transpose() * + ConstStridedMap(A, K, M, OuterStride(lda))); + return; + default: + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransB"; + } + } + case CblasTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += + alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) * + ConstStridedMap(A, M, K, OuterStride(lda)).transpose()); + return; + case CblasTrans: + C_mat.noalias() += + alpha * (ConstStridedMap(B, K, N, OuterStride(ldb)).transpose() * + ConstStridedMap(A, M, K, OuterStride(lda)).transpose()); + return; + default: + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransB"; + } + } + default: + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransA"; + } +} + template <> void Gemv( - const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha, - const float* A, const float* x, const float beta, float* y, + const CBLAS_TRANSPOSE TransA, + const int M, + const int N, + const float alpha, + const float* A, + const float* x, + const float beta, + float* y, CPUContext* context) { EigenVectorMap y_vec(y, TransA == CblasNoTrans ? M : N); if (beta == 0) { @@ -490,11 +560,15 @@ DEFINE_BROADCAST_BINARY_FUNCTION(Div, /) #undef DEFINE_BROADCAST_BINARY_FUNCTION #undef DELEGATE_BROADCAST_BINARY_FUNCTION -#define CAFFE2_SPECIALIZED_SET(T) \ -template <> \ -void Set(const int N, const T alpha, T *Y, \ - CPUContext* context) { \ - EigenVectorMap(Y, N).setConstant(alpha); \ +#define CAFFE2_SPECIALIZED_SET(T) \ +template <> \ +void Set(const int N, const T alpha, T *Y, \ + CPUContext* context) { \ + if (alpha == (T) 0) { \ + memset(Y, 0, N * sizeof(T)); \ + } else { \ + EigenVectorMap(Y, N).setConstant(alpha); \ + } \ } CAFFE2_SPECIALIZED_SET(float); @@ -610,18 +684,6 @@ void Select( } } -// Function uses casting from int to unsigned to compare if value of -// parameter a is greater or equal to zero and lower than value of -// parameter b. The b parameter is of type signed and is always -// positive, -// therefore its value is always lower than 0x800... where casting -// negative value of a parameter converts it to value higher than -// 0x800... -// The casting allows to use one condition instead of two. -inline bool is_a_ge_zero_and_a_lt_b(int a, int b) { - return static_cast(a) < static_cast(b); -} - template <> void Im2col( const float* data_im, @@ -946,6 +1008,90 @@ void Col2im( } } +template <> +void BiasCHW( + const float* bias, + const int bias_channels, + const int image_size, + float* image, + CPUContext* context) { + // Sum the per-channel bias into every image plane + for (int c = 0; c < bias_channels; ++c) { + float b = bias[c]; + +#ifdef __ARM_NEON__ + float32x4_t vBias = vdupq_n_f32(b); + + // We give alignment hints for additional speed, so handle the + // non-vectorizable prologue separately + constexpr int kVecSizeInFloat = sizeof(float32x4_t) / sizeof(float); + + // FIXME: if input < kVecSizeInFloat, can't vectorize at all + + int prologue = + kVecSizeInFloat - + // remainder in floats + (((uintptr_t) image) % (sizeof(float32x4_t))) / sizeof(float); + + int i = 0; + // Prologue loop + for (; i < prologue; ++i) { + image[i] += b; + } + + // The loop is manually unrolled by 8 + constexpr int kUnroll = 8; + constexpr int kFloatsPerLoop = kUnroll * kVecSizeInFloat; + + int remainder = image_size - prologue; + int vectorizable = prologue + (remainder / kFloatsPerLoop) * kFloatsPerLoop; + + // Vectorizable body + for (; i < vectorizable; i += kFloatsPerLoop) { + // Manually unrolled + float32x4_t v0 = vld1q_f32_aligned(image + i + 0); + float32x4_t v1 = vld1q_f32_aligned(image + i + 4); + float32x4_t v2 = vld1q_f32_aligned(image + i + 8); + float32x4_t v3 = vld1q_f32_aligned(image + i + 12); + float32x4_t v4 = vld1q_f32_aligned(image + i + 16); + float32x4_t v5 = vld1q_f32_aligned(image + i + 20); + float32x4_t v6 = vld1q_f32_aligned(image + i + 24); + float32x4_t v7 = vld1q_f32_aligned(image + i + 28); + + v0 = vaddq_f32(v0, vBias); + v1 = vaddq_f32(v1, vBias); + v2 = vaddq_f32(v2, vBias); + v3 = vaddq_f32(v3, vBias); + v4 = vaddq_f32(v4, vBias); + v5 = vaddq_f32(v5, vBias); + v6 = vaddq_f32(v6, vBias); + v7 = vaddq_f32(v7, vBias); + + vst1q_f32_aligned(image + i + 0, v0); + vst1q_f32_aligned(image + i + 4, v1); + vst1q_f32_aligned(image + i + 8, v2); + vst1q_f32_aligned(image + i + 12, v3); + vst1q_f32_aligned(image + i + 16, v4); + vst1q_f32_aligned(image + i + 20, v5); + vst1q_f32_aligned(image + i + 24, v6); + vst1q_f32_aligned(image + i + 28, v7); + } + + // Non-vectorizable epilogue + for (; i < image_size; ++i) { + image[i] += b; + } +#else + // Non-NEON CPU implementation + for (int i = 0; i < image_size; ++i) { + image[i] += b; + } +#endif // __ARM_NEON__ + + image += image_size; + } +} + template <> void CopyMatrix( const size_t itemsize, const int M, const int N, const void* A, diff --git a/caffe2/utils/mkl/sgemm_pack.h b/caffe2/utils/mkl/sgemm_pack.h index cdfd63ca804e..5d45f8455910 100644 --- a/caffe2/utils/mkl/sgemm_pack.h +++ b/caffe2/utils/mkl/sgemm_pack.h @@ -6,8 +6,8 @@ namespace caffe2 { namespace mkl { struct MKLPackedMatrix { - char identifier_; - char trans_; + CBLAS_IDENTIFIER identifier_; + CBLAS_TRANSPOSE trans_; int m_; int n_; int k_; @@ -16,8 +16,8 @@ struct MKLPackedMatrix { float* data_ = nullptr; MKLPackedMatrix( - const char identifier, - const char trans, + const CBLAS_IDENTIFIER identifier, + const CBLAS_TRANSPOSE trans, const int m, const int n, const int k, @@ -31,14 +31,15 @@ struct MKLPackedMatrix { k_(k), alpha_(alpha), ld_(ld) { - data_ = sgemm_alloc(&identifier, &m, &n, &k); + data_ = cblas_sgemm_alloc(identifier, m, n, k); CAFFE_ENFORCE(data_, "MKL runtime error: cannot allocate sgemm memory."); - sgemm_pack(&identifier, &trans, &m, &n, &k, &alpha, src, &ld, data_); + cblas_sgemm_pack( + CblasRowMajor, identifier, trans, m, n, k, alpha, src, ld, data_); } ~MKLPackedMatrix() { if (data_) { - sgemm_free(data_); + cblas_sgemm_free(data_); } } }; diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h index 2ce622361dde..78a4b7586561 100644 --- a/caffe2/utils/proto_utils.h +++ b/caffe2/utils/proto_utils.h @@ -213,7 +213,9 @@ template Argument MakeArgument(const string& name, const T& value); template -void AddArgument(const string& name, const T& value, OperatorDef* def); +inline void AddArgument(const string& name, const T& value, OperatorDef* def) { + GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value)); +} } // namespace caffe2 diff --git a/caffe2/utils/threadpool/ThreadPool.cc b/caffe2/utils/threadpool/ThreadPool.cc new file mode 100644 index 000000000000..12bdf8ac4368 --- /dev/null +++ b/caffe2/utils/threadpool/ThreadPool.cc @@ -0,0 +1,231 @@ +#include "caffe2/utils/threadpool/ThreadPool.h" +#include "caffe2/core/logging.h" + +#if CAFFE2_THREADPOOL_MOBILE + +namespace caffe2 { + +// Default smallest amount of work that will be partitioned between +// multiple threads; the runtime value is configurable +constexpr size_t kDefaultMinWorkSize = 80; + +ThreadPool::ThreadPool(int numThreads) + : fn_(nullptr), + workItemsPending_(0), + currentWorkId_(0), + threadsReady_(0), + minWorkSize_(kDefaultMinWorkSize) { + std::lock_guard guard(mutex_); + + // All worker threads (and the main thread) have a ThreadInfo + for (auto i = 0; i < numThreads; ++i) { + threadInfo_.emplace_back( + std::unique_ptr(new ThreadInfo(i, numThreads))); + } + + // The first ThreadInfo is for the main thread + for (auto i = 1; i < numThreads; ++i) { + auto pInfo = &(threadInfo_[i]); + auto fn = [pInfo, this, i]() { + (*pInfo)->threadMain(i, this); + }; + + threads_.emplace_back(std::thread(std::move(fn))); + } +} + +ThreadPool::~ThreadPool() { + { + std::lock_guard guard(mutex_); + for (auto& info : threadInfo_) { + info->wantExit_ = true; + } + } + + threadStartMonitor_.notify_all(); + + // Wait on all threads to exit + for (auto& thread : threads_) { + thread.join(); + } +} + +int +ThreadPool::getNumThreads() const { + std::lock_guard guard(executionMutex_); + + return threadInfo_.size(); +} + + // Sets the minimum work size (range) for which to invoke the + // threadpool; work sizes smaller than this will just be run on the + // main (calling) thread +void +ThreadPool::setMinWorkSize(size_t size) { + std::lock_guard guard(executionMutex_); + + minWorkSize_ = size; +} + +void +ThreadPool::run(const std::function& fn, size_t range) { + std::lock_guard guard(executionMutex_); + + // If there are no worker threads, or if the range is too small (too + // little work), just run locally + if (threads_.size() == 0 || range < minWorkSize_) { + for (size_t i = 0; i < range; ++i) { + fn(0, i); + } + + return; + } + + // Set up thread state + { + std::unique_lock lock(mutex_); + + // We've guaranteed that all threads have finished work for the + // previous round, but we don't want threads to read new work + // information out of order. Wait for all of the old threads to + // check in first + while (threadsReady_ < threads_.size()) { + threadReadyMonitor_.wait(lock); + } + + // Our threads are ready, and are waiting for us to start them. + threadsReady_ = 0; + + fn_ = &fn; + + auto numThreads = threadInfo_.size(); + size_t workUnitsPerThread = (numThreads + range - 1) / numThreads; + + for (size_t i = 0; i < numThreads; ++i) { + auto& threadInfo = threadInfo_[i]; + + threadInfo->rangeStart_ = std::min(i * workUnitsPerThread, range); + threadInfo->rangeEnd_ = std::min((i + 1) * workUnitsPerThread, range); + threadInfo->rangeLength_ = + threadInfo->rangeEnd_ - threadInfo->rangeStart_; + } + + workItemsPending_ = range; + ++currentWorkId_; + } + + // Wake all worker threads + threadStartMonitor_.notify_all(); + + // We participate as well + bool done = threadInfo_[0]->runAndSteal(0, this); + + // This thread may have been the one to finish all the work + if (!done) { + // Wait until we get signalled back + { + std::unique_lock lock(mutex_); + while (workItemsPending_.load() > 0) { + threadDoneMonitor_.wait(lock); + } + } + } +} + +void +ThreadInfo::threadMain(int threadId, ThreadPool* pool) { + long lastProcessedWorkId = 0; + + while (true) { + { + // Kick off + std::unique_lock lock(pool->mutex_); + int numAtBarrier = ++(pool->threadsReady_); + // numThreads includes main thread, we only care about the # of + // worker threads here + if (numAtBarrier == (numThreads_ - 1)) { + pool->threadReadyMonitor_.notify_one(); + } + + // Wait on main to give us new work + while (!wantExit_ && pool->currentWorkId_ <= lastProcessedWorkId) { + pool->threadStartMonitor_.wait(lock); + } + + // Whether or not we actually do some work, this is the new work + // item we're handling + lastProcessedWorkId = pool->currentWorkId_; + } + + if (wantExit_) { + return; + } + + bool shouldSignal = runAndSteal(threadId, pool); + + if (shouldSignal) { + std::lock_guard guard(pool->mutex_); + pool->threadDoneMonitor_.notify_one(); + } + } +} + +bool +ThreadInfo::runAndSteal(int threadId, ThreadPool* pool) { + auto lambdaFunctionToRun = pool->fn_; + auto localItemsCompleted = 0; + + /* Process thread's own range of items */ + auto curItem = rangeStart_; + while (true) { + auto curRangeLength = --rangeLength_; // atomic + + if (curRangeLength < 0) { + // someone stole all of our work + break; + } + + (*lambdaFunctionToRun)(threadId, curItem); + + ++curItem; + ++localItemsCompleted; + } + + /* Done, now look for other threads' items to steal */ + for (auto i = (threadId_ + 1) % numThreads_; + i != threadId_; + i = (i + 1) % numThreads_) { + auto& otherThread = pool->threadInfo_[i]; + + while (true) { + auto curRangeLength = --(otherThread->rangeLength_); // atomic + + if (curRangeLength < 0) { + break; + } + + // We're successfully stealing a work item from the other thread + auto itemId = --(otherThread->rangeEnd_); // atomic + + (*lambdaFunctionToRun)(threadId, itemId); + ++localItemsCompleted; + } + } + + if (localItemsCompleted > 0) { + auto numRemaining = + (pool->workItemsPending_ -= localItemsCompleted); // atomic + DCHECK_GE(numRemaining, 0); + + if (numRemaining == 0) { + // We were the last thread to finish all work + return true; + } + } + + return false; +} + +} // namespace caffe2 + +#endif // CAFFE2_THREADPOOL_MOBILE diff --git a/caffe2/utils/threadpool/ThreadPool.h b/caffe2/utils/threadpool/ThreadPool.h new file mode 100644 index 000000000000..d36928e0ce9e --- /dev/null +++ b/caffe2/utils/threadpool/ThreadPool.h @@ -0,0 +1,143 @@ +#ifndef CAFFE2_UTILS_THREADPOOL_H_ +#define CAFFE2_UTILS_THREADPOOL_H_ + +#include "ThreadPoolCommon.h" + +#ifndef CAFFE2_THREADPOOL_MOBILE +#error "mobile build state not defined" +#endif + +// ThreadPool only used in mobile builds at the moment +#if CAFFE2_THREADPOOL_MOBILE + +#include +#include +#include +#include +#include +#include +#include + +// +// A work-stealing threadpool loosely based off of pthreadpool +// + +namespace caffe2 { + +struct ThreadPool; + +struct __attribute__((__aligned__(64))) ThreadInfo { + ThreadInfo(int threadId, int numThreads) : + rangeStart_(0), + rangeEnd_(0), + rangeLength_(0), + wantExit_(false), + threadId_(threadId), + numThreads_(numThreads) { + } + + // Entry point for all worker threads + void threadMain(int threadId, ThreadPool* pool); + + // Runs a task, and when we're done with our local queue, steal from + // neighbors. + // Returns true if all work is done (we were the last thread to do + // work) + bool runAndSteal(int threadId, ThreadPool* pool); + + // Index of first element in the work range. + // Before processing a new element the owning worker thread + // increments this value. + long rangeStart_; + + // Index of the element after the last element of the work range. + // Before processing a new element the stealing worker thread + // decrements this value. + std::atomic rangeEnd_; + + // The number of elements in the work range. + // Due to race conditions range_length <= range_end - range_start. + // The owning worker thread must decrement this value before + // incrementing @a range_start. + // The stealing worker thread must decrement this value before + // decrementing @a range_end. + std::atomic rangeLength_; + + // Should this thread exit? + bool wantExit_; + + // Our thread index + int threadId_; + + // How many threads are there in total? + int numThreads_; +}; + +class __attribute__((__aligned__(64))) ThreadPool { + public: + ThreadPool(int numThreads); + ~ThreadPool(); + + // Returns the number of threads currently in use + int getNumThreads() const; + + // Sets the minimum work size (range) for which to invoke the + // threadpool; work sizes smaller than this will just be run on the + // main (calling) thread + void setMinWorkSize(size_t size); + + // Called to schedule work on the threadpool + void run(const std::function& fn, size_t range); + + protected: + friend struct ThreadInfo; + + // What we are currently working on + const std::function* fn_; + + // How many work items are outstanding? When this reaches 0, our + // main thread is resumed + std::atomic workItemsPending_; + + // Current work ID that we're running; sequentially increments + long currentWorkId_; + + // Mutex that guards all monitors and state updates + std::mutex mutex_; + + // Main thread waits on this before running new work, to make sure + // that all worker threads have looped back around to await new work + std::condition_variable threadReadyMonitor_; + + // All worker threads wait on this to make sure that they have work + // available for processing + std::condition_variable threadStartMonitor_; + + // Main thread waits on this before returning to the thread pool + // caller; note that we don't actually wait on the worker threads + // saying that they're all done (woken up); we only check when the + // thread pool is called again + std::condition_variable threadDoneMonitor_; + + // How many threads are ready to process new work? + size_t threadsReady_; + + // The first entry is always for the main thread + std::vector> threadInfo_; + + // Set of threads that we are managing + std::vector threads_; + + // What's the minimum work size for using the threadpool? + size_t minWorkSize_; + + // Mutex that ensures that only one user call to the ThreadPool is + // outstanding + mutable std::mutex executionMutex_; +}; + +} // namespace caffe2 + +#endif // CAFFE2_THREADPOOL_MOBILE + +#endif // CAFFE2_UTILS_THREADPOOL_H_ diff --git a/caffe2/utils/threadpool/ThreadPoolCommon.h b/caffe2/utils/threadpool/ThreadPoolCommon.h new file mode 100644 index 000000000000..1540712af9c6 --- /dev/null +++ b/caffe2/utils/threadpool/ThreadPoolCommon.h @@ -0,0 +1,26 @@ +#ifndef CAFFE2_UTILS_THREADPOOL_COMMON_H_ +#define CAFFE2_UTILS_THREADPOOL_COMMON_H_ + +#ifdef __APPLE__ +#include +#endif + +// caffe2 depends upon NNPACK, which depends upon this threadpool, so +// unfortunately we can't reference core/common.h here + +// This is copied from core/common.h's definition of CAFFE2_MOBILE +// Define enabled when building for iOS or Android devices +#if !defined(CAFFE2_THREADPOOL_MOBILE) +#if defined(__ANDROID__) +#define CAFFE2_ANDROID 1 +#define CAFFE2_THREADPOOL_MOBILE 1 +#elif (defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define CAFFE2_IOS 1 +#define CAFFE2_THREADPOOL_MOBILE 1 +#else +#define CAFFE2_THREADPOOL_MOBILE 0 +#endif // ANDROID / IOS +#endif // CAFFE2_THREADPOOL_MOBILE + +#endif // CAFFE2_UTILS_THREADPOOL_COMMON_H_ diff --git a/caffe2/utils/threadpool/pthreadpool.cc b/caffe2/utils/threadpool/pthreadpool.cc new file mode 100644 index 000000000000..c22117936277 --- /dev/null +++ b/caffe2/utils/threadpool/pthreadpool.cc @@ -0,0 +1,169 @@ +/* Standard C headers */ +#include +#include +#include +#include +#include + +/* POSIX headers */ +#include +#include + +/* Library header */ +#include "caffe2/core/logging.h" +#include "caffe2/utils/fixed_divisor.h" +#include "caffe2/utils/threadpool/pthreadpool.h" + +#if CAFFE2_THREADPOOL_MOBILE + +static inline size_t divide_round_up(size_t dividend, size_t divisor) { + if (dividend % divisor == 0) { + return dividend / divisor; + } else { + return dividend / divisor + 1; + } +} + +static inline size_t min(size_t a, size_t b) { + return a < b ? a : b; +} + +struct compute_1d_tiled_context { + pthreadpool_function_1d_tiled_t function; + void* argument; + size_t range; + size_t tile; +}; + +static void compute_1d_tiled(const struct compute_1d_tiled_context* context, size_t linear_index) { + const size_t tile_index = linear_index; + const size_t index = tile_index * context->tile; + const size_t tile = min(context->tile, context->range - index); + context->function(context->argument, index, tile); +} + +void pthreadpool_compute_1d_tiled( + pthreadpool_t threadpool, + pthreadpool_function_1d_tiled_t function, + void* argument, + size_t range, + size_t tile) +{ + if (threadpool == NULL) { + /* No thread pool provided: execute function sequentially on the calling thread */ + for (size_t i = 0; i < range; i += tile) { + function(argument, i, min(range - i, tile)); + } + } else { + /* Execute in parallel on the thread pool using linearized index */ + const size_t tile_range = divide_round_up(range, tile); + struct compute_1d_tiled_context context = { + .function = function, + .argument = argument, + .range = range, + .tile = tile + }; + pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_1d_tiled, &context, tile_range); + } +} + +struct compute_2d_context { + pthreadpool_function_2d_t function; + void* argument; + caffe2::FixedDivisor range_j; +}; + +static void compute_2d(const struct compute_2d_context* context, size_t linear_index) { + DCHECK_LE(linear_index, std::numeric_limits::max()); + + int q; + int r; + context->range_j.divMod((int) linear_index, q, r); + context->function(context->argument, q, r); +} + +void pthreadpool_compute_2d( + struct pthreadpool* threadpool, + pthreadpool_function_2d_t function, + void* argument, + size_t range_i, + size_t range_j) +{ + if (threadpool == NULL) { + /* No thread pool provided: execute function sequentially on the calling thread */ + for (size_t i = 0; i < range_i; i++) { + for (size_t j = 0; j < range_j; j++) { + function(argument, i, j); + } + } + } else { + DCHECK_LE(range_i * range_j, (size_t) std::numeric_limits::max()); + /* Execute in parallel on the thread pool using linearized index */ + struct compute_2d_context context = { + .function = function, + .argument = argument, + .range_j = caffe2::FixedDivisor(range_j) + }; + pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d, &context, range_i * range_j); + } +} + +struct compute_2d_tiled_context { + pthreadpool_function_2d_tiled_t function; + void* argument; + caffe2::FixedDivisor tile_range_j; + size_t range_i; + size_t range_j; + size_t tile_i; + size_t tile_j; +}; + +static void compute_2d_tiled(const struct compute_2d_tiled_context* context, size_t linear_index) { + int q; + int r; + + context->tile_range_j.divMod(linear_index, q, r); + const size_t max_tile_i = context->tile_i; + const size_t max_tile_j = context->tile_j; + const size_t index_i = q * max_tile_i; + const size_t index_j = r * max_tile_j; + const size_t tile_i = min(max_tile_i, context->range_i - index_i); + const size_t tile_j = min(max_tile_j, context->range_j - index_j); + context->function(context->argument, index_i, index_j, tile_i, tile_j); +} + +void pthreadpool_compute_2d_tiled( + pthreadpool_t threadpool, + pthreadpool_function_2d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t tile_i, + size_t tile_j) +{ + if (threadpool == NULL) { + /* No thread pool provided: execute function sequentially on the calling thread */ + for (size_t i = 0; i < range_i; i += tile_i) { + for (size_t j = 0; j < range_j; j += tile_j) { + function(argument, i, j, min(range_i - i, tile_i), min(range_j - j, tile_j)); + } + } + } else { + /* Execute in parallel on the thread pool using linearized index */ + const size_t tile_range_i = divide_round_up(range_i, tile_i); + const size_t tile_range_j = divide_round_up(range_j, tile_j); + DCHECK_LE(tile_range_i * tile_range_j, (size_t) std::numeric_limits::max()); + struct compute_2d_tiled_context context = { + .function = function, + .argument = argument, + .tile_range_j = caffe2::FixedDivisor(tile_range_j), + .range_i = range_i, + .range_j = range_j, + .tile_i = tile_i, + .tile_j = tile_j + }; + pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d_tiled, &context, tile_range_i * tile_range_j); + } +} + +#endif // CAFFE2_THREADPOOL_MOBILE diff --git a/caffe2/utils/threadpool/pthreadpool.h b/caffe2/utils/threadpool/pthreadpool.h new file mode 100644 index 000000000000..2f7e54eed780 --- /dev/null +++ b/caffe2/utils/threadpool/pthreadpool.h @@ -0,0 +1,111 @@ +// pthreadpool header from https://github.com/Maratyszcza/pthreadpool +// for NNPACK +#ifndef CAFFE2_UTILS_PTHREADPOOL_H_ +#define CAFFE2_UTILS_PTHREADPOOL_H_ + +#include "ThreadPoolCommon.h" + +#ifndef CAFFE2_THREADPOOL_MOBILE +#error "mobile build state not defined" +#endif + +// ThreadPool only used in mobile builds at the moment +#if CAFFE2_THREADPOOL_MOBILE + +#include // for size_t + +typedef struct pthreadpool* pthreadpool_t; + +typedef void (*pthreadpool_function_1d_t)(void*, size_t); +typedef void (*pthreadpool_function_1d_tiled_t)(void*, size_t, size_t); +typedef void (*pthreadpool_function_2d_t)(void*, size_t, size_t); +typedef void (*pthreadpool_function_2d_tiled_t)(void*, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_function_3d_t)(void*, size_t, size_t, size_t); + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Creates a thread pool with the specified number of threads. + * + * @param[in] threads_count The number of threads in the thread pool. + * A value of 0 has special interpretation: it creates a thread for each + * processor core available in the system. + * + * @returns A pointer to an opaque thread pool object. + * On error the function returns NULL and sets errno accordingly. + */ +pthreadpool_t pthreadpool_create(size_t threads_count); + +/** + * Queries the number of threads in a thread pool. + * + * @param[in] threadpool The thread pool to query. + * + * @returns The number of threads in the thread pool. + */ +size_t pthreadpool_get_threads_count(pthreadpool_t threadpool); + + +/** + * Processes items in parallel using threads from a thread pool. + * + * When the call returns, all items have been processed and the thread pool is + * ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param[in] threadpool The thread pool to use for parallelisation. + * @param[in] function The function to call for each item. + * @param[in] argument The first argument passed to the @a function. + * @param[in] items The number of items to process. The @a function + * will be called once for each item. + */ +void pthreadpool_compute_1d( + pthreadpool_t threadpool, + pthreadpool_function_1d_t function, + void* argument, + size_t range); + +void pthreadpool_compute_1d_tiled( + pthreadpool_t threadpool, + pthreadpool_function_1d_tiled_t function, + void* argument, + size_t range, + size_t tile); + +void pthreadpool_compute_2d( + pthreadpool_t threadpool, + pthreadpool_function_2d_t function, + void* argument, + size_t range_i, + size_t range_j); + +void pthreadpool_compute_2d_tiled( + pthreadpool_t threadpool, + pthreadpool_function_2d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t tile_i, + size_t tile_j); + +/** + * Terminates threads in the thread pool and releases associated resources. + * + * @warning Accessing the thread pool after a call to this function constitutes + * undefined behaviour and may cause data corruption. + * + * @param[in,out] threadpool The thread pool to destroy. + */ +void pthreadpool_destroy(pthreadpool_t threadpool); + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif // CAFFE2_THREADPOOL_MOBILE + +#endif // CAFFE2_UTILS_PTHREADPOOL_H_ diff --git a/caffe2/utils/threadpool/pthreadpool_impl.cc b/caffe2/utils/threadpool/pthreadpool_impl.cc new file mode 100644 index 000000000000..f6bbb3700b36 --- /dev/null +++ b/caffe2/utils/threadpool/pthreadpool_impl.cc @@ -0,0 +1,26 @@ +#include "caffe2/utils/threadpool/pthreadpool.h" +#include "caffe2/utils/threadpool/pthreadpool_impl.h" +#include "caffe2/utils/threadpool/ThreadPool.h" + +#if CAFFE2_THREADPOOL_MOBILE + +// +// External API +// + +void pthreadpool_compute_1d(struct pthreadpool* threadpool, + pthreadpool_function_1d_t function, + void* argument, + size_t range) { + threadpool->pool_->run( + [function, argument](int threadId, size_t workId) { + function(argument, workId); + }, + range); +} + +size_t pthreadpool_get_threads_count(struct pthreadpool* threadpool) { + return threadpool->pool_->getNumThreads(); +} + +#endif // CAFFE2_THREADPOOL_MOBILE diff --git a/caffe2/utils/threadpool/pthreadpool_impl.h b/caffe2/utils/threadpool/pthreadpool_impl.h new file mode 100644 index 000000000000..ecc82b170829 --- /dev/null +++ b/caffe2/utils/threadpool/pthreadpool_impl.h @@ -0,0 +1,30 @@ +#ifndef CAFFE2_UTILS_PTHREADPOOL_IMPL_H_ +#define CAFFE2_UTILS_PTHREADPOOL_IMPL_H_ + +#include "ThreadPoolCommon.h" + +#ifndef CAFFE2_THREADPOOL_MOBILE +#error "mobile build state not defined" +#endif + +#if CAFFE2_THREADPOOL_MOBILE + +namespace caffe2 { + +struct ThreadPool; + +} // namespace caffe2 + +extern "C" { + +// Wrapper for the caffe2 threadpool for the usage of NNPACK +struct pthreadpool { + pthreadpool(caffe2::ThreadPool* pool) : pool_(pool) {} + caffe2::ThreadPool* pool_; +}; + +} // extern "C" + +#endif // CAFFE2_THREADPOOL_MOBILE + +#endif // CAFFE2_UTILS_PTHREADPOOL_IMPL_H_