diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc
index 22c63ecd9822..d9fe93fec921 100644
--- a/caffe2/operators/dataset_ops.cc
+++ b/caffe2/operators/dataset_ops.cc
@@ -274,6 +274,264 @@ class CheckDatasetConsistencyOp : public Operator<CPUContext> {
   TreeIterator iterator_;
 };
 
+/**
+ * Simple wrapper class allowing easy traversal of the tensors representing
+ * the hierarchical structure.
+ */
+class TreeWalker {
+ public:
+  TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor)
+      : inputs_(inputs), cursor_(cursor), sizes_(cursor.it.numOffsetFields()) {
+    if (cursor.offsets.empty()) {
+      cursor.offsets.assign(cursor.it.numOffsetFields(), 0);
+    }
+
+    for (int fieldId = 0; fieldId < cursor_.it.fields().size(); ++fieldId) {
+      fields_.emplace_back(*this, fieldId);
+    }
+
+    gatherLengthData();
+
+    gatherSizeLimits();
+
+    // The invariant we hold is that we are always one step ahead
+    advance();
+  }
+
+  // Returns the number of records in the dataset
+  inline TOffset size() const {
+    return limits_.at(0);
+  }
+
+  void advance() {
+    prevOffsets_ = cursor_.offsets;
+    cursor_.it.advance(lengths_, cursor_.offsets, sizes_, limits_, 1);
+  }
+
+ private:
+  inline const TensorCPU& input(int32_t idx) const {
+    return inputs_[idx]->Get<TensorCPU>();
+  }
+
+  // TODO: Change to fieldDesc
+  inline const TreeIterator::FieldDesc& field(int idx) const {
+    return cursor_.it.fields().at(idx);
+  }
+
+  inline int lengthIdx(int fieldId) const {
+    return field(fieldId).lengthFieldId + 1;
+  }
+
+  inline TOffset offset(int fieldId) const {
+    return prevOffsets_[lengthIdx(fieldId)];
+  }
+
+  std::vector<TIndex> fieldDim(int fieldId) const {
+    auto tensorDim = input(fieldId).dims();
+    tensorDim[0] = sizes_[lengthIdx(fieldId)];
+    return tensorDim;
+  }
+
+  void* fieldPtr(int fieldId) const {
+    auto& in = input(fieldId);
+    return (char*)in.raw_data() +
+        offset(fieldId) * in.size_from_dim(1) * in.meta().itemsize();
+  }
+
+ public:
+  // Simple proxy class exposing a nicer API for field access
+  class Field {
+   public:
+    Field(TreeWalker& walker, int fieldId)
+        : walker_(walker), fieldId_(fieldId) {}
+
+    inline std::vector<TIndex> dim() const {
+      return walker_.fieldDim(fieldId_);
+    }
+
+    inline const TypeMeta& meta() const {
+      return walker_.input(fieldId_).meta();
+    }
+
+    inline void* ptr() const {
+      return walker_.fieldPtr(fieldId_);
+    }
+
+   private:
+    const TreeWalker& walker_;
+    const int fieldId_;
+  };
+
+  // Note that a reference is returned: if advance() is called, the fields
+  // will be updated to reflect the new state.
+  inline const std::vector<Field>& fields() const {
+    return fields_;
+  }
+
+ private:
+  void gatherLengthData() {
+    static const TLength lenZero = 0;
+    lengths_.resize(cursor_.it.numLengthFields());
+    for (int i = 0; i < lengths_.size(); ++i) {
+      auto& in = input(cursor_.it.lengthField(i).id);
+      if (in.size() > 0) {
+        lengths_[i] = in.data<TLength>();
+      } else {
+        lengths_[i] = &lenZero;
+      }
+    }
+  }
+
+  void gatherSizeLimits() {
+    limits_.assign(sizes_.size(), std::numeric_limits<TOffset>::max());
+    for (auto fieldId = 0; fieldId < cursor_.it.fields().size(); ++fieldId) {
+      auto lengthFieldIdx = lengthIdx(fieldId);
+      limits_[lengthFieldIdx] =
+          std::min(limits_[lengthFieldIdx], (TOffset)input(fieldId).dims()[0]);
+    }
+  }
+
+  const vector<const Blob*>& inputs_;
+  TreeCursor& cursor_;
+  std::vector<Field> fields_;
+
+  std::vector<const TLength*> lengths_;
+  std::vector<TOffset> limits_;
+  std::vector<TOffset> sizes_;
+  std::vector<TOffset> offsets_;
+  std::vector<TOffset> prevOffsets_;
+};
+
+using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;
+
+class PackRecordsOp : public Operator<CPUContext> {
+ public:
+  PackRecordsOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator(operator_def, ws),
+        fields_(OperatorBase::GetRepeatedArgument<std::string>("fields")) {}
+
+  bool RunOnDevice() override {
+    // There should be one input per field
+    CAFFE_ENFORCE_EQ(InputSize(), fields_.size());
+    CAFFE_ENFORCE_EQ(OutputSize(), 1);
+
+    TreeCursor cursor((TreeIterator(fields_)));
+
+    TreeWalker walker(Inputs(), cursor);
+
+    Output(0)->Resize(walker.size());
+
+    // Output(0)->raw_mutable_data(TypeMeta::Make<SharedTensorVectorPtr>()));
+    auto* dst = Output(0)->mutable_data<SharedTensorVectorPtr>();
+
+    for (int batchId = 0; batchId < walker.size(); ++batchId) {
+      dst[batchId] = std::make_shared<std::vector<TensorCPU>>();
+      dst[batchId]->reserve(walker.fields().size());
+
+      for (const auto& field : walker.fields()) {
+        dst[batchId]->emplace_back(field.dim());
+        auto& tensor = dst[batchId]->back();
+        context_.template CopyItems<CPUContext, CPUContext>(
+            field.meta(),
+            tensor.size(),
+            field.ptr() /* src */,
+            tensor.raw_mutable_data(field.meta()) /* dst */);
+      }
+
+      walker.advance();
+    }
+
+    return true;
+  }
+
+ private:
+  std::vector<std::string> fields_;
+};
+
+class UnPackRecordsOp : public Operator<CPUContext> {
+ public:
+  UnPackRecordsOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator(operator_def, ws),
+        fields_(OperatorBase::GetRepeatedArgument<std::string>("fields")) {}
+
+  bool RunOnDevice() override {
+    const auto* inputs = Input(0).template data<SharedTensorVectorPtr>();
+    const auto numRows = Input(0).size();
+
+    CAFFE_ENFORCE_GE(numRows, 0);
+
+    if (numRows == 0) {
+      return true;
+    }
+
+    const auto& inputZero = inputs[0];
+    CAFFE_ENFORCE(inputZero);
+
+    const auto numTensors = inputZero->size();
+
+    CAFFE_ENFORCE_EQ(numTensors, fields_.size());
+    CAFFE_ENFORCE_EQ(numTensors, OutputSize());
+
+    // Precompute the output sizes to avoid resizing
+    std::vector<std::vector<TIndex>> outputDims(numTensors);
+
+    for (int i = 0; i < numTensors; ++i) {
+      outputDims[i] = inputs[0]->at(i).dims();
+      outputDims[i][0] = 0;
+    }
+
+    for (int i = 0; i < numRows; ++i) {
+      CAFFE_ENFORCE(inputs[i]);
+      for (int j = 0; j < inputs[i]->size(); ++j) {
+        const auto& input = inputs[i]->at(j);
+        const auto& inputZeroTensor = inputZero->at(j);
+
+        // Checks to ensure that dimensions/sizes match
+        CAFFE_ENFORCE_EQ(inputZeroTensor.ndim(), input.ndim());
+        CAFFE_ENFORCE(inputZeroTensor.meta() == input.meta());
+        // We skip the first dimension, because we concat along the first.
+        for (int k = 1; k < input.ndim(); ++k) {
+          CAFFE_ENFORCE_EQ(input.dims()[k], inputZeroTensor.dims()[k]);
+        }
+
+        outputDims[j][0] += input.dim(0);
+      }
+    }
+
+    // Resize to the final output size
+    std::vector<void*> destinations(numTensors);
+    for (int i = 0; i < numTensors; ++i) {
+      Output(i)->Resize(outputDims[i]);
+      destinations[i] = Output(i)->raw_mutable_data(inputZero->at(i).meta());
+    }
+
+    for (int i = 0; i < numRows; ++i) {
+      for (int j = 0; j < numTensors; ++j) {
+        const auto& input = inputs[i]->at(j);
+        // Skip empty tensors
+        if (input.size() == 0) {
+          continue;
+        }
+
+        context_.CopyItems<CPUContext, CPUContext>(
+            inputZero->at(j).meta(),
+            input.size(),
+            input.raw_data() /* src */,
+            destinations[j] /* dst */
+        );
+
+        destinations[j] =
+            (char*)destinations[j] + input.size() * inputZero->at(j).itemsize();
+      }
+    }
+
+    return true;
+  }
+
+ private:
+  std::vector<std::string> fields_;
+};
+
 class ReadNextBatchOp : public Operator<CPUContext> {
  public:
   ReadNextBatchOp(const OperatorDef& operator_def, Workspace* ws)
@@ -803,6 +1061,8 @@ REGISTER_CPU_OPERATOR(CreateTensorVector, CreateTensorVectorOp);
 REGISTER_CPU_OPERATOR(TensorVectorSize, TensorVectorSizeOp);
 REGISTER_CPU_OPERATOR(ConcatTensorVector, ConcatTensorVectorOp);
 REGISTER_CPU_OPERATOR(CollectTensor, CollectTensorOp);
+REGISTER_CPU_OPERATOR(PackRecords, PackRecordsOp);
+REGISTER_CPU_OPERATOR(UnPackRecords, UnPackRecordsOp);
 
 OPERATOR_SCHEMA(CreateTreeCursor)
     .NumInputs(0)
@@ -1048,6 +1308,34 @@ output vectors.
 )DOC")
     .Arg("num_to_collect", "The max number of tensors to collect");
 
+OPERATOR_SCHEMA(PackRecords)
+    .NumInputs(1, INT_MAX)
+    .NumOutputs(1)
+    .SetDoc(R"DOC(
+    Given a dataset under a schema specified by the `fields` argument, packs all the input tensors into one,
+    where each tensor element represents a row of data (batch of size 1). This format allows easier use with the rest of the Caffe2 operators.
+    )DOC")
+    .Arg(
+        "fields",
+        "List of strings representing the string names in the format "
+        "specified in the doc for CreateTreeCursor.")
+    .Output(
+        0,
+        "tensor",
+        "One-dimensional tensor with the complex type SharedTensorVectorPtr. To reverse it back to the original input it has to be inserted into UnPackRecordsOp.");
+
+OPERATOR_SCHEMA(UnPackRecords)
+    .NumInputs(1)
+    .NumOutputs(1, INT_MAX)
+    .SetDoc(R"DOC(
+    Given a packed dataset (packed by the PackRecordsOp) and the `fields` argument describing the dataset's schema, returns the original dataset format.
+    The number of returned tensors is equal to the number of fields in the `fields` argument.
+ )DOC") + .Arg( + "fields", + "List of strings representing the string names in the format" + "specified in the doc for CreateTreeCursor."); + SHOULD_NOT_DO_GRADIENT(CreateTreeCursor); SHOULD_NOT_DO_GRADIENT(ResetCursor); SHOULD_NOT_DO_GRADIENT(ReadNextBatch); @@ -1060,9 +1348,12 @@ SHOULD_NOT_DO_GRADIENT(CreateTensorVector); SHOULD_NOT_DO_GRADIENT(TensorVectorSize); SHOULD_NOT_DO_GRADIENT(ConcatTensorVector); SHOULD_NOT_DO_GRADIENT(CollectTensor); +SHOULD_NOT_DO_GRADIENT(UnPack); +SHOULD_NOT_DO_GRADIENT(Pack); } // namespace CAFFE_KNOWN_TYPE(std::unique_ptr); CAFFE_KNOWN_TYPE(TensorVectorPtr); +CAFFE_KNOWN_TYPE(SharedTensorVectorPtr); namespace { diff --git a/caffe2/python/operator_test/dataset_ops_test.py b/caffe2/python/operator_test/dataset_ops_test.py index ef2d7abfb027..c355f47b3dfe 100644 --- a/caffe2/python/operator_test/dataset_ops_test.py +++ b/caffe2/python/operator_test/dataset_ops_test.py @@ -6,17 +6,27 @@ import numpy as np from caffe2.python import core, workspace, dataset from caffe2.python.dataset import Const from caffe2.python.schema import ( - List, Field, Struct, Scalar, Map, - from_blob_list, FetchRecord, NewRecord, FeedRecord) + List, Field, Struct, Scalar, Map, from_blob_list, FetchRecord, NewRecord, + FeedRecord +) from caffe2.python.test_util import TestCase +import numpy.testing as npt + +import string + +from hypothesis import given +import hypothesis.strategies as st + def _assert_arrays_equal(actual, ref, err_msg): if ref.dtype.kind in ('S', 'O'): np.testing.assert_array_equal(actual, ref, err_msg=err_msg) else: np.testing.assert_allclose( - actual, ref, atol=1e-4, rtol=1e-4, err_msg=err_msg) + actual, ref, atol=1e-4, + rtol=1e-4, err_msg=err_msg + ) def _assert_records_equal(actual, ref): @@ -24,13 +34,169 @@ def _assert_records_equal(actual, ref): assert isinstance(ref, Field) b1 = actual.field_blobs() b2 = ref.field_blobs() - assert(len(b1) == len(b2)), 'Records have different lengths: %d vs. %d' % ( - len(b1), len(b2)) + assert (len(b1) == len(b2)), 'Records have different lengths: %d vs. %d' % ( + len(b1), len(b2) + ) for name, d1, d2 in zip(ref.field_names(), b1, b2): _assert_arrays_equal(d1, d2, err_msg='Mismatch in field %s.' 
                              % name)
 
 
+@st.composite
+def _sparse_features_map(draw, num_records, **kwargs):
+    sparse_maps_lengths = draw(
+        st.lists(
+            st.integers(min_value=1, max_value=10),
+            min_size=num_records,
+            max_size=num_records
+        )
+    )
+
+    sparse_maps_total_length = sum(sparse_maps_lengths)
+
+    sparse_keys = draw(
+        st.lists(
+            st.integers(min_value=1, max_value=100),
+            min_size=sparse_maps_total_length,
+            max_size=sparse_maps_total_length,
+            unique=True
+        )
+    )
+
+    sparse_values_lengths = draw(
+        st.lists(
+            st.integers(min_value=1, max_value=10),
+            min_size=sparse_maps_total_length,
+            max_size=sparse_maps_total_length
+        )
+    )
+
+    total_sparse_values_lengths = sum(sparse_values_lengths)
+
+    sparse_values = draw(
+        # max_value is max int64
+        st.lists(
+            st.integers(min_value=1, max_value=9223372036854775807),
+            min_size=total_sparse_values_lengths,
+            max_size=total_sparse_values_lengths
+        )
+    )
+
+    return [
+        sparse_maps_lengths,
+        sparse_keys,
+        sparse_values_lengths,
+        sparse_values,
+    ]
+
+
+@st.composite
+def _dense_features_map(draw, num_records, **kwargs):
+    float_lengths = draw(
+        st.lists(
+            st.integers(min_value=1, max_value=10),
+            min_size=num_records,
+            max_size=num_records
+        )
+    )
+
+    total_length = sum(float_lengths)
+
+    float_keys = draw(
+        st.lists(
+            st.integers(min_value=1, max_value=100),
+            min_size=total_length,
+            max_size=total_length,
+            unique=True
+        )
+    )
+
+    float_values = draw(
+        st.lists(st.floats(),
+                 min_size=total_length,
+                 max_size=total_length)
+    )
+
+    return [float_lengths, float_keys, float_values]
+
+
+@st.composite
+def _dataset(draw, min_elements=3, max_elements=10, **kwargs):
+    schema = Struct(
+        # Dense Features Map
+        ('floats', Map(
+            Scalar(np.int32), Scalar(np.float32)
+        )),
+        # Sparse Features Map
+        ('int_lists', Map(
+            Scalar(np.int32),
+            List(Scalar(np.int64)),
+        )),
+        # Complex Type
+        ('text', Scalar(str)),
+    )
+
+    num_records = draw(
+        st.integers(min_value=min_elements,
+                    max_value=max_elements)
+    )
+
+    raw_dense_features_map_contents = draw(_dense_features_map(num_records))
+
+    raw_sparse_features_map_contents = draw(_sparse_features_map(num_records))
+
+    raw_text_contents = [
+        draw(
+            st.lists(
+                st.text(alphabet=string.ascii_lowercase),
+                min_size=num_records,
+                max_size=num_records
+            )
+        )
+    ]
+
+    # Concatenate all raw contents to a single one
+    contents_raw = raw_dense_features_map_contents + raw_sparse_features_map_contents + raw_text_contents
+
+    contents = from_blob_list(schema, contents_raw)
+
+    return (schema, contents, num_records)
+
+
 class TestDatasetOps(TestCase):
+    @given(_dataset())
+    def test_pack_unpack(self, input):
+        """
+        Tests if packing and unpacking of the whole dataset is an identity.
+        """
+        (schema, contents, num_records) = input
+
+        dataset_fields = schema.field_names()
+
+        net = core.Net('pack_unpack_net')
+
+        batch = NewRecord(net, contents)
+        FeedRecord(batch, contents)
+
+        packed = net.PackRecords(
+            batch.field_blobs(), 1,
+            fields=dataset_fields
+        )
+
+        unpacked = packed.UnPackRecords(
+            [], len(dataset_fields),
+            fields=dataset_fields
+        )
+
+        workspace.RunNetOnce(net)
+
+        for initial_tensor, unpacked_tensor in zip(
+            batch.field_blobs(), unpacked
+        ):
+            npt.assert_array_equal(
+                workspace.FetchBlob(initial_tensor),
+                workspace.FetchBlob(unpacked_tensor)
+            )
+
     def test_dataset_ops(self):
         """
         1. Defining the schema of our dataset.
@@ -42,30 +208,34 @@ class TestDatasetOps(TestCase): ('dense', Scalar((np.float32, 3))), # could represent a feature map from feature ID to float value ('floats', Map( - Scalar(np.int32), - Scalar(np.float32))), + Scalar(np.int32), Scalar(np.float32) + )), # could represent a multi-valued categorical feature map ('int_lists', Map( Scalar(np.int32), List(Scalar(np.int64)), )), # could represent a multi-valued, weighted categorical feature map - ('id_score_pairs', Map( - Scalar(np.int32), - Map( - Scalar(np.int64), - Scalar(np.float32), - keys_name='ids', - values_name='scores'), - )), + ( + 'id_score_pairs', Map( + Scalar(np.int32), + Map( + Scalar(np.int64), + Scalar(np.float32), + keys_name='ids', + values_name='scores' + ), + ) + ), # additional scalar information - ('metadata', Struct( - ('user_id', Scalar(np.int64)), - ('user_embed', Scalar((np.float32, 2))), - ('query', Scalar(str)), - )), + ( + 'metadata', Struct( + ('user_id', Scalar(np.int64)), + ('user_embed', Scalar((np.float32, 2))), + ('query', Scalar(str)), + ) + ), ) - """ This is what the flattened fields for this schema look like, along with its type. Each one of these fields will be stored, read and @@ -90,13 +260,11 @@ class TestDatasetOps(TestCase): ('metadata:query', str), ] zipped = zip( - expected_fields, - schema.field_names(), - schema.field_types()) + expected_fields, schema.field_names(), schema.field_types() + ) for (ref_name, ref_type), name, dtype in zipped: self.assertEquals(ref_name, name) self.assertEquals(np.dtype(ref_type), dtype) - """ 2. The contents of our dataset. @@ -129,7 +297,6 @@ class TestDatasetOps(TestCase): ] # convert the above content to ndarrays, checking against the schema contents = from_blob_list(schema, contents_raw) - """ 3. Creating and appending to the dataset. We first create an empty dataset with the given schema. @@ -145,7 +312,6 @@ class TestDatasetOps(TestCase): writer = ds.writer(init_net=net) writer.write_record(net, content_blobs) workspace.RunNetOnce(net) - """ 4. Iterating through the dataset contents. 
@@ -155,32 +321,63 @@ class TestDatasetOps(TestCase): entries_raw = [ ( [[1.1, 1.2, 1.3]], # dense - [1], [11], [1.1], # floats - [2], [11, 12], [2, 4], [111, 112, 121, 122, 123, 124], # intlst - [1], [11], [1], [111], [11.1], # id score pairs - [123], [[0.2, 0.8]], ['dog posts'], # metadata + [1], + [11], + [1.1], # floats + [2], + [11, 12], + [2, 4], + [111, 112, 121, 122, 123, 124], # intlst + [1], + [11], + [1], + [111], + [11.1], # id score pairs + [123], + [[0.2, 0.8]], + ['dog posts'], # metadata ), ( [[2.1, 2.2, 2.3]], # dense - [2], [21, 22], [2.1, 2.2], # floats - [0], [], [], [], # int list - [2], [21, 22], [1, 2], [211, 221, 222], [21.1, 22.1, 22.2], - [234], [[0.5, 0.5]], ['friends who like to'], # metadata + [2], + [21, 22], + [2.1, 2.2], # floats + [0], + [], + [], + [], # int list + [2], + [21, 22], + [1, 2], + [211, 221, 222], + [21.1, 22.1, 22.2], + [234], + [[0.5, 0.5]], + ['friends who like to'], # metadata ), ( [[3.1, 3.2, 3.3]], # dense - [3], [31, 32, 33], [3.1, 3.2, 3.3], # floats - [1], [31], [3], [311, 312, 313], # int lst - [2], [31, 32], [2, 3], [311, 312, 321, 322, 323], + [3], + [31, 32, 33], + [3.1, 3.2, 3.3], # floats + [1], + [31], + [3], + [311, 312, 313], # int lst + [2], + [31, 32], + [2, 3], + [311, 312, 321, 322, 323], [31.1, 31.2, 32.1, 32.2, 32.3], # id score list - [456], [[0.7, 0.3]], ['posts about ca'], # metadata + [456], + [[0.7, 0.3]], + ['posts about ca'], # metadata ), # after the end of the dataset, we will keep getting empty vectors - ([],) * 16, - ([],) * 16, + ([], ) * 16, + ([], ) * 16, ] entries = [from_blob_list(schema, e) for e in entries_raw] - """ Let's go ahead and create the reading nets. We will run `read` net multiple times and assert that we are reading the @@ -198,7 +395,6 @@ class TestDatasetOps(TestCase): workspace.RunNet(str(read_next_net)) actual = FetchRecord(batch) _assert_records_equal(actual, entry) - """ 5. Reading/writing in a single plan @@ -212,7 +408,6 @@ class TestDatasetOps(TestCase): reset_net = core.Net('reset_net') reader.reset(reset_net) read_step, batch = reader.execution_step() - """ We will add the line number * 1000 to the feature ids. """ process_net = core.Net('process') line_no = Const(process_net, 0, dtype=np.int32) @@ -221,7 +416,6 @@ class TestDatasetOps(TestCase): field = batch.floats.keys.get() process_net.Print(field, []) process_net.Add([field, line_no], field, broadcast=1, axis=0) - """ Lets create a second dataset and append to it. """ ds2 = dataset.Dataset(schema, name='dataset2') ds2.init_empty(reset_net) @@ -231,14 +425,12 @@ class TestDatasetOps(TestCase): # generality of the example commit_net = core.Net('commit') writer.commit(commit_net) - """ Time to create and run a plan which will do the processing """ plan = core.Plan('process') plan.AddStep(core.execution_step('reset', reset_net)) plan.AddStep(read_step.AddNet(process_net)) plan.AddStep(core.execution_step('commit', commit_net)) workspace.RunPlan(plan) - """ Now we should have dataset2 populated. """ @@ -246,7 +438,6 @@ class TestDatasetOps(TestCase): field = ds2_data.floats.keys field.set(blob=field.get() - [1000, 2000, 2000, 3000, 3000, 3000]) _assert_records_equal(contents, ds2_data) - """ 6. Slicing a dataset @@ -256,7 +447,6 @@ class TestDatasetOps(TestCase): subschema = Struct(('top_level', schema.int_lists.values)) int_list_contents = contents.int_lists.values.field_names() self.assertEquals(len(subschema.field_names()), len(int_list_contents)) - """ 7. 
Random Access a dataset @@ -282,8 +472,6 @@ class TestDatasetOps(TestCase): workspace.RunNet(str(read_next_net)) actual = FetchRecord(batch) _assert_records_equal(actual, entry) - - """ 8. Sort and shuffle a dataset @@ -328,31 +516,40 @@ class TestDatasetOps(TestCase): num_to_collect=7, ) plan = core.Plan('collect_data') - plan.AddStep(core.execution_step('collect_data', - [collect_net], num_iter=1)) + plan.AddStep( + core.execution_step('collect_data', [collect_net], + num_iter=1) + ) workspace.RunPlan(plan) reference_result = workspace.FetchBlob('output') self.assertSequenceEqual( [item for sublist in reference_result for item in sublist], - [1, 2, 3, 4, 5, 6]) + [1, 2, 3, 4, 5, 6] + ) plan = core.Plan('collect_data') - plan.AddStep(core.execution_step('collect_data', - [collect_net], num_iter=2)) + plan.AddStep( + core.execution_step('collect_data', [collect_net], + num_iter=2) + ) workspace.RunPlan(plan) reference_result = workspace.FetchBlob('output') self.assertSequenceEqual( [item for sublist in reference_result for item in sublist], - [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]) + [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6] + ) plan = core.Plan('collect_data') - plan.AddStep(core.execution_step('collect_data', - [collect_net], num_iter=3)) + plan.AddStep( + core.execution_step('collect_data', [collect_net], + num_iter=3) + ) workspace.RunPlan(plan) reference_result = workspace.FetchBlob('output') self.assertSequenceEqual( [item for sublist in reference_result for item in sublist], - [3, 4, 5, 6, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2]) + [3, 4, 5, 6, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2] + ) def test_collect_tensor_ops(self): init_net = core.Net('init_net') @@ -382,9 +579,12 @@ class TestDatasetOps(TestCase): plan = core.Plan('collect_data') plan.AddStep(core.execution_step('collect_init', init_net)) - plan.AddStep(core.execution_step('collect_data', - [reader_net, collect_net], - num_iter=max_example_to_cover)) + plan.AddStep( + core.execution_step( + 'collect_data', [reader_net, collect_net], + num_iter=max_example_to_cover + ) + ) workspace.RunPlan(plan) # concat the collected tensors @@ -401,14 +601,19 @@ class TestDatasetOps(TestCase): # check data reference_result = workspace.FetchBlob(bconcated_map[blobs[0]]) - self.assertEqual(reference_result.shape, - (min(num_to_collect, max_example_to_cover), 2)) + self.assertEqual( + reference_result.shape, + (min(num_to_collect, max_example_to_cover), 2) + ) size = workspace.FetchBlob(bsize_map[blobs[0]]) self.assertEqual(tuple(), size.shape) self.assertEqual(min(num_to_collect, max_example_to_cover), size.item()) - hist, _ = np.histogram(reference_result[:, 0], bins=10, - range=(1, max_example_to_cover)) + hist, _ = np.histogram( + reference_result[:, 0], + bins=10, + range=(1, max_example_to_cover) + ) print('Sample histogram: {}'.format(hist)) self.assertTrue(all(hist > 0.7 * (num_to_collect / 10))) @@ -416,6 +621,7 @@ class TestDatasetOps(TestCase): result = workspace.FetchBlob(bconcated_map[blobs[i]]) self.assertEqual(reference_result.tolist(), result.tolist()) + if __name__ == "__main__": import unittest unittest.main()
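
For readers who want to try the new operators outside the hypothesis-based test, below is a minimal sketch of how PackRecords/UnPackRecords might be driven from Python. It only uses calls that already appear in this diff (NewRecord, FeedRecord, from_blob_list, net.PackRecords, packed.UnPackRecords); the two-field schema, the blob contents, and the net name are illustrative stand-ins rather than part of the patch, and it assumes a Caffe2 build that includes this change.

```python
import numpy as np

from caffe2.python import core, workspace
from caffe2.python.schema import (
    Map, Scalar, Struct, FeedRecord, NewRecord, from_blob_list
)

# A tiny two-record dataset: one dense float per record plus a sparse
# int32 -> float32 feature map (flattened as lengths/keys/values).
schema = Struct(
    ('dense', Scalar(np.float32)),
    ('floats', Map(Scalar(np.int32), Scalar(np.float32))),
)
fields = schema.field_names()

contents = from_blob_list(schema, [
    [1.5, 2.5],       # dense: one value per record
    [1, 2],           # floats lengths: record 0 has 1 pair, record 1 has 2
    [11, 21, 22],     # floats keys
    [1.1, 2.1, 2.2],  # floats values
])

net = core.Net('pack_unpack_example')
batch = NewRecord(net, contents)
FeedRecord(batch, contents)

# Pack the flattened field blobs into a single blob whose elements are
# per-row tensor vectors, then unpack back into one output blob per field.
packed = net.PackRecords(batch.field_blobs(), 1, fields=fields)
unpacked = packed.UnPackRecords([], len(fields), fields=fields)

workspace.RunNetOnce(net)

for original, roundtripped in zip(batch.field_blobs(), unpacked):
    np.testing.assert_array_equal(
        workspace.FetchBlob(original), workspace.FetchBlob(roundtripped)
    )
```

As in test_pack_unpack above, the round trip should be an identity: every unpacked blob matches the blob that was fed in.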