caffe2: datasets pack/unpack

Summary:
Two new operators to pack and unpack a dataset. This is so that we can
re-use other operators that do not understand the schema format. The immediate
use-case is to use it with a partition operator.

Packing works by splitting the input into separate tensors, putting them in a
vector, and wrapping that vector in a shared_ptr (as opposed to a unique_ptr, so
that it can be copied).

Unpack takes the packed input and concatenates it back into the original format.
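For illustration, here is a minimal end-to-end sketch distilled from the
test_pack_unpack unit test added in this diff; the one-field schema, the tiny
hand-written contents, and the net name are placeholders for brevity, not part
of the change itself:

import numpy as np
from caffe2.python import core, workspace
from caffe2.python.schema import (
    Struct, Scalar, Map, from_blob_list, NewRecord, FeedRecord
)

# A tiny two-row dataset with a single feature-map field.
schema = Struct(('floats', Map(Scalar(np.int32), Scalar(np.float32))))
contents = from_blob_list(schema, [
    np.array([1, 2], dtype=np.int32),             # floats:lengths (per row)
    np.array([11, 21, 22], dtype=np.int32),       # floats:keys
    np.array([1.1, 2.1, 2.2], dtype=np.float32),  # floats:values
])
dataset_fields = schema.field_names()

net = core.Net('pack_unpack_example')
batch = NewRecord(net, contents)
FeedRecord(batch, contents)

# Pack: one input blob per field, one output blob whose elements are
# shared_ptr<vector<TensorCPU>>, one element per dataset row. Row-wise
# operators (e.g. a partition operator) can then be applied to this blob
# without knowing anything about the schema.
packed = net.PackRecords(batch.field_blobs(), 1, fields=dataset_fields)

# Unpack: concatenates the per-row tensors back into one blob per field.
unpacked = packed.UnPackRecords([], len(dataset_fields), fields=dataset_fields)

workspace.RunNetOnce(net)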

I also had a hard time understanding the iteration, so I created a TreeWalker
that hides the complexity of operating on all the arrays and exposes short,
purpose-specific functions that, at least for me, are easier to understand.

Reviewed By: dzhulgakov

Differential Revision: D4918002

fbshipit-source-id: ecbf9196ed25e886a94383961176b8c84dde2d2f
Author: Janusz Kudelka (2017-04-24 15:59:38 -07:00)
Committed by: Facebook Github Bot
parent 9cb901caf0
commit 902409be56
2 changed files with 564 additions and 67 deletions


@ -274,6 +274,264 @@ class CheckDatasetConsistencyOp : public Operator<CPUContext> {
TreeIterator iterator_;
};
/**
* Simple wrapper class allowing an easy traversal of the tensors representing
* the hierarchical structure.
*/
class TreeWalker {
public:
TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor)
: inputs_(inputs), cursor_(cursor), sizes_(cursor.it.numOffsetFields()) {
if (cursor.offsets.empty()) {
cursor.offsets.assign(cursor.it.numOffsetFields(), 0);
}
for (int fieldId = 0; fieldId < cursor_.it.fields().size(); ++fieldId) {
fields_.emplace_back(*this, fieldId);
}
gatherLengthData();
gatherSizeLimits();
// The invariant we hold is that we are always one step ahead
advance();
}
// Returns the number of records in a dataset
inline TOffset size() const {
return limits_.at(0);
}
void advance() {
prevOffsets_ = cursor_.offsets;
cursor_.it.advance(lengths_, cursor_.offsets, sizes_, limits_, 1);
}
private:
inline const TensorCPU& input(int32_t idx) const {
return inputs_[idx]->Get<TensorCPU>();
}
// TODO: Change to fieldDesc
inline const TreeIterator::FieldDesc& field(int idx) const {
return cursor_.it.fields().at(idx);
}
inline int lengthIdx(int fieldId) const {
return field(fieldId).lengthFieldId + 1;
}
inline TOffset offset(int fieldId) const {
return prevOffsets_[lengthIdx(fieldId)];
}
std::vector<TIndex> fieldDim(int fieldId) const {
auto tensorDim = input(fieldId).dims();
tensorDim[0] = sizes_[lengthIdx(fieldId)];
return tensorDim;
}
void* fieldPtr(int fieldId) const {
auto& in = input(fieldId);
return (char*)in.raw_data() +
offset(fieldId) * in.size_from_dim(1) * in.meta().itemsize();
}
public:
// Simple Proxy class to expose nicer API for field access
class Field {
public:
Field(TreeWalker& walker, int fieldId)
: walker_(walker), fieldId_(fieldId) {}
inline std::vector<TIndex> dim() const {
return walker_.fieldDim(fieldId_);
}
inline const TypeMeta& meta() const {
return walker_.input(fieldId_).meta();
}
inline void* ptr() const {
return walker_.fieldPtr(fieldId_);
}
private:
const TreeWalker& walker_;
const int fieldId_;
};
// Notice that a reference is returned. If advance() is called the fields will
// be updated to represent the new state.
inline const std::vector<Field>& fields() const {
return fields_;
}
private:
void gatherLengthData() {
static const TLength lenZero = 0;
lengths_.resize(cursor_.it.numLengthFields());
for (int i = 0; i < lengths_.size(); ++i) {
auto& in = input(cursor_.it.lengthField(i).id);
if (in.size() > 0) {
lengths_[i] = in.data<int>();
} else {
lengths_[i] = &lenZero;
}
}
}
void gatherSizeLimits() {
limits_.assign(sizes_.size(), std::numeric_limits<TOffset>::max());
for (auto fieldId = 0; fieldId < cursor_.it.fields().size(); ++fieldId) {
auto lengthFieldIdx = lengthIdx(fieldId);
limits_[lengthFieldIdx] =
std::min(limits_[lengthFieldIdx], (TOffset)input(fieldId).dims()[0]);
}
}
const vector<const Blob*>& inputs_;
TreeCursor& cursor_;
std::vector<Field> fields_;
std::vector<const TLength*> lengths_;
std::vector<TOffset> limits_;
std::vector<TOffset> sizes_;
std::vector<TOffset> offsets_;
std::vector<TOffset> prevOffsets_;
};
using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;
class PackRecordsOp : public Operator<CPUContext> {
public:
PackRecordsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws),
fields_(OperatorBase::GetRepeatedArgument<std::string>("fields")) {}
bool RunOnDevice() override {
// There should be one input per field
CAFFE_ENFORCE_EQ(InputSize(), fields_.size());
CAFFE_ENFORCE_EQ(OutputSize(), 1);
TreeCursor cursor((TreeIterator(fields_)));
TreeWalker walker(Inputs(), cursor);
Output(0)->Resize(walker.size());
// Output(0)->raw_mutable_data(TypeMeta::Make<SharedTensorVectorPtr>()));
auto* dst = Output(0)->mutable_data<SharedTensorVectorPtr>();
for (int batchId = 0; batchId < walker.size(); ++batchId) {
dst[batchId] = std::make_shared<std::vector<TensorCPU>>();
dst[batchId]->reserve(walker.fields().size());
for (const auto& field : walker.fields()) {
dst[batchId]->emplace_back(field.dim());
auto& tensor = dst[batchId]->back();
context_.template CopyItems<CPUContext, CPUContext>(
field.meta(),
tensor.size(),
field.ptr() /* src */,
tensor.raw_mutable_data(field.meta()) /* dst */);
}
walker.advance();
}
return true;
}
private:
std::vector<std::string> fields_;
};
class UnPackRecordsOp : public Operator<CPUContext> {
public:
UnPackRecordsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator(operator_def, ws),
fields_(OperatorBase::GetRepeatedArgument<std::string>("fields")) {}
bool RunOnDevice() override {
const auto* inputs = Input(0).template data<SharedTensorVectorPtr>();
const auto numRows = Input(0).size();
CAFFE_ENFORCE_GE(numRows, 0);
if (numRows == 0) {
return true;
}
const auto& inputZero = inputs[0];
CAFFE_ENFORCE(inputZero);
const auto numTensors = inputZero->size();
CAFFE_ENFORCE_EQ(numTensors, fields_.size());
CAFFE_ENFORCE_EQ(numTensors, OutputSize());
// Precompute the output sizes to avoid resizing
std::vector<std::vector<TIndex>> outputDims(numTensors);
for (int i = 0; i < numTensors; ++i) {
outputDims[i] = inputs[0]->at(i).dims();
outputDims[i][0] = 0;
}
for (int i = 0; i < numRows; ++i) {
CAFFE_ENFORCE(inputs[i]);
for (int j = 0; j < inputs[i]->size(); ++j) {
const auto& input = inputs[i]->at(j);
const auto& inputZeroTensor = inputZero->at(j);
// Checks to ensure that dimensions/sizes match
CAFFE_ENFORCE_EQ(inputZeroTensor.ndim(), input.ndim());
CAFFE_ENFORCE(inputZeroTensor.meta() == input.meta());
// Skip the first (outermost) dimension, because we concatenate along it.
for (int k = 1; k < input.ndim(); ++k) {
CAFFE_ENFORCE_EQ(input.dims()[k], inputZeroTensor.dims()[k]);
}
outputDims[j][0] += input.dim(0);
}
}
// Resize to the final output size
std::vector<void*> destinations(numTensors);
for (int i = 0; i < numTensors; ++i) {
Output(i)->Resize(outputDims[i]);
destinations[i] = Output(i)->raw_mutable_data(inputZero->at(i).meta());
}
for (int i = 0; i < numRows; ++i) {
for (int j = 0; j < numTensors; ++j) {
const auto& input = inputs[i]->at(j);
// Skip empty tensors
if (input.size() == 0) {
continue;
}
context_.CopyItems<CPUContext, CPUContext>(
inputZero->at(j).meta(),
input.size(),
input.raw_data() /* src */,
destinations[j] /* dst */
);
destinations[j] =
(char*)destinations[j] + input.size() * inputZero->at(j).itemsize();
}
}
return true;
}
private:
std::vector<std::string> fields_;
};
class ReadNextBatchOp : public Operator<CPUContext> {
public:
ReadNextBatchOp(const OperatorDef& operator_def, Workspace* ws)
@ -803,6 +1061,8 @@ REGISTER_CPU_OPERATOR(CreateTensorVector, CreateTensorVectorOp<CPUContext>);
REGISTER_CPU_OPERATOR(TensorVectorSize, TensorVectorSizeOp<CPUContext>);
REGISTER_CPU_OPERATOR(ConcatTensorVector, ConcatTensorVectorOp<CPUContext>);
REGISTER_CPU_OPERATOR(CollectTensor, CollectTensorOp<CPUContext>);
REGISTER_CPU_OPERATOR(PackRecords, PackRecordsOp);
REGISTER_CPU_OPERATOR(UnPackRecords, UnPackRecordsOp);
OPERATOR_SCHEMA(CreateTreeCursor)
.NumInputs(0)
@ -1048,6 +1308,34 @@ output vectors.
)DOC")
.Arg("num_to_collect", "The max number of tensors to collect");
OPERATOR_SCHEMA(PackRecords)
.NumInputs(1, INT_MAX)
.NumOutputs(1)
.SetDoc(R"DOC(
Given a dataset under a schema specified by the `fields` argument, packs all the input tensors into one,
where each element of the output represents a row of data (a batch of size 1). This format allows easier use with the rest of the Caffe2 operators.
)DOC")
.Arg(
"fields",
"List of strings representing the string names in the format"
"specified in the doc for CreateTreeCursor.")
.Output(
0,
"tensor",
"One dimensional tensor having a complex type of SharedTensorVectorPtr. In order to reverse it back to the original input it has to be inserted into UnPackRecordsOp.");
OPERATOR_SCHEMA(UnPackRecords)
.NumInputs(1)
.NumOutputs(1, INT_MAX)
.SetDoc(R"DOC(
Given a packed dataset (packed by the PackRecordsOp) and the `fields` argument describing the dataset's schema, returns the original dataset format.
The number of returned tensors is equal to the number of fields in the `fields` argument.
)DOC")
.Arg(
"fields",
"List of strings representing the string names in the format"
"specified in the doc for CreateTreeCursor.");
SHOULD_NOT_DO_GRADIENT(CreateTreeCursor);
SHOULD_NOT_DO_GRADIENT(ResetCursor);
SHOULD_NOT_DO_GRADIENT(ReadNextBatch);
@ -1060,9 +1348,12 @@ SHOULD_NOT_DO_GRADIENT(CreateTensorVector);
SHOULD_NOT_DO_GRADIENT(TensorVectorSize);
SHOULD_NOT_DO_GRADIENT(ConcatTensorVector);
SHOULD_NOT_DO_GRADIENT(CollectTensor);
SHOULD_NOT_DO_GRADIENT(UnPackRecords);
SHOULD_NOT_DO_GRADIENT(PackRecords);
} // namespace
CAFFE_KNOWN_TYPE(std::unique_ptr<TreeCursor>);
CAFFE_KNOWN_TYPE(TensorVectorPtr<CPUContext>);
CAFFE_KNOWN_TYPE(SharedTensorVectorPtr);
namespace {


@ -6,17 +6,27 @@ import numpy as np
from caffe2.python import core, workspace, dataset
from caffe2.python.dataset import Const
from caffe2.python.schema import (
List, Field, Struct, Scalar, Map, from_blob_list, FetchRecord, NewRecord,
FeedRecord
)
from caffe2.python.test_util import TestCase
import numpy.testing as npt
import string
from hypothesis import given
import hypothesis.strategies as st
def _assert_arrays_equal(actual, ref, err_msg):
if ref.dtype.kind in ('S', 'O'):
np.testing.assert_array_equal(actual, ref, err_msg=err_msg)
else:
np.testing.assert_allclose(
actual, ref, atol=1e-4,
rtol=1e-4, err_msg=err_msg
)
def _assert_records_equal(actual, ref):
@ -24,13 +34,169 @@ def _assert_records_equal(actual, ref):
assert isinstance(ref, Field)
b1 = actual.field_blobs()
b2 = ref.field_blobs()
assert (len(b1) == len(b2)), 'Records have different lengths: %d vs. %d' % (
len(b1), len(b2)
)
for name, d1, d2 in zip(ref.field_names(), b1, b2):
_assert_arrays_equal(d1, d2, err_msg='Mismatch in field %s.' % name)
@st.composite
def _sparse_features_map(draw, num_records, **kwargs):
sparse_maps_lengths = draw(
st.lists(
st.integers(min_value=1, max_value=10),
min_size=num_records,
max_size=num_records
)
)
sparse_maps_total_length = sum(sparse_maps_lengths)
sparse_keys = draw(
st.lists(
st.integers(min_value=1, max_value=100),
min_size=sparse_maps_total_length,
max_size=sparse_maps_total_length,
unique=True
)
)
sparse_values_lengths = draw(
st.lists(
st.integers(min_value=1, max_value=10),
min_size=sparse_maps_total_length,
max_size=sparse_maps_total_length
)
)
total_sparse_values_lengths = sum(sparse_values_lengths)
sparse_values = draw(
# max_value is max int64
st.lists(
st.integers(min_value=1, max_value=9223372036854775807),
min_size=total_sparse_values_lengths,
max_size=total_sparse_values_lengths
)
)
return [
sparse_maps_lengths,
sparse_keys,
sparse_values_lengths,
sparse_values,
]
@st.composite
def _dense_features_map(draw, num_records, **kwargs):
float_lengths = draw(
st.lists(
st.integers(min_value=1, max_value=10),
min_size=num_records,
max_size=num_records
)
)
total_length = sum(float_lengths)
float_keys = draw(
st.lists(
st.integers(min_value=1, max_value=100),
min_size=total_length,
max_size=total_length,
unique=True
)
)
float_values = draw(
st.lists(st.floats(),
min_size=total_length,
max_size=total_length)
)
return [float_lengths, float_keys, float_values]
@st.composite
def _dataset(draw, min_elements=3, max_elements=10, **kwargs):
schema = Struct(
# Dense Features Map
('floats', Map(
Scalar(np.int32), Scalar(np.float32)
)),
# Sparse Features Map
('int_lists', Map(
Scalar(np.int32),
List(Scalar(np.int64)),
)),
# Complex Type
('text', Scalar(str)),
)
num_records = draw(
st.integers(min_value=min_elements,
max_value=max_elements)
)
raw_dense_features_map_contents = draw(_dense_features_map(num_records))
raw_sparse_features_map_contents = draw(_sparse_features_map(num_records))
raw_text_contents = [
draw(
st.lists(
st.text(alphabet=string.ascii_lowercase),
min_size=num_records,
max_size=num_records
)
)
]
# Concatenate all raw contents to a single one
contents_raw = raw_dense_features_map_contents + raw_sparse_features_map_contents + raw_text_contents
contents = from_blob_list(schema, contents_raw)
return (schema, contents, num_records)
class TestDatasetOps(TestCase):
@given(_dataset())
def test_pack_unpack(self, input):
"""
Tests if packing and unpacking of the whole dataset is an identity.
"""
(schema, contents, num_records) = input
dataset_fields = schema.field_names()
net = core.Net('pack_unpack_net')
batch = NewRecord(net, contents)
FeedRecord(batch, contents)
packed = net.PackRecords(
batch.field_blobs(), 1,
fields=dataset_fields
)
unpacked = packed.UnPackRecords(
[], len(dataset_fields),
fields=dataset_fields
)
workspace.RunNetOnce(net)
for initial_tensor, unpacked_tensor in zip(
batch.field_blobs(), unpacked
):
npt.assert_array_equal(
workspace.FetchBlob(initial_tensor),
workspace.FetchBlob(unpacked_tensor)
)
def test_dataset_ops(self):
"""
1. Defining the schema of our dataset.
@ -42,30 +208,34 @@ class TestDatasetOps(TestCase):
('dense', Scalar((np.float32, 3))),
# could represent a feature map from feature ID to float value
('floats', Map(
Scalar(np.int32), Scalar(np.float32)
)),
# could represent a multi-valued categorical feature map
('int_lists', Map(
Scalar(np.int32),
List(Scalar(np.int64)),
)),
# could represent a multi-valued, weighted categorical feature map
(
'id_score_pairs', Map(
Scalar(np.int32),
Map(
Scalar(np.int64),
Scalar(np.float32),
keys_name='ids',
values_name='scores'
),
)
),
# additional scalar information
(
'metadata', Struct(
('user_id', Scalar(np.int64)),
('user_embed', Scalar((np.float32, 2))),
('query', Scalar(str)),
)
),
)
"""
This is what the flattened fields for this schema look like, along
with its type. Each one of these fields will be stored, read and
@ -90,13 +260,11 @@ class TestDatasetOps(TestCase):
('metadata:query', str),
]
zipped = zip(
expected_fields, schema.field_names(), schema.field_types()
)
for (ref_name, ref_type), name, dtype in zipped:
self.assertEquals(ref_name, name)
self.assertEquals(np.dtype(ref_type), dtype)
"""
2. The contents of our dataset.
@ -129,7 +297,6 @@ class TestDatasetOps(TestCase):
]
# convert the above content to ndarrays, checking against the schema
contents = from_blob_list(schema, contents_raw)
"""
3. Creating and appending to the dataset.
We first create an empty dataset with the given schema.
@ -145,7 +312,6 @@ class TestDatasetOps(TestCase):
writer = ds.writer(init_net=net)
writer.write_record(net, content_blobs)
workspace.RunNetOnce(net)
"""
4. Iterating through the dataset contents.
@ -155,32 +321,63 @@ class TestDatasetOps(TestCase):
entries_raw = [
(
[[1.1, 1.2, 1.3]], # dense
[1],
[11],
[1.1], # floats
[2],
[11, 12],
[2, 4],
[111, 112, 121, 122, 123, 124], # intlst
[1],
[11],
[1],
[111],
[11.1], # id score pairs
[123],
[[0.2, 0.8]],
['dog posts'], # metadata
),
(
[[2.1, 2.2, 2.3]], # dense
[2],
[21, 22],
[2.1, 2.2], # floats
[0],
[],
[],
[], # int list
[2],
[21, 22],
[1, 2],
[211, 221, 222],
[21.1, 22.1, 22.2],
[234],
[[0.5, 0.5]],
['friends who like to'], # metadata
),
(
[[3.1, 3.2, 3.3]], # dense
[3],
[31, 32, 33],
[3.1, 3.2, 3.3], # floats
[1],
[31],
[3],
[311, 312, 313], # int lst
[2],
[31, 32],
[2, 3],
[311, 312, 321, 322, 323],
[31.1, 31.2, 32.1, 32.2, 32.3], # id score list
[456],
[[0.7, 0.3]],
['posts about ca'], # metadata
),
# after the end of the dataset, we will keep getting empty vectors
([], ) * 16,
([], ) * 16,
]
entries = [from_blob_list(schema, e) for e in entries_raw]
"""
Let's go ahead and create the reading nets.
We will run `read` net multiple times and assert that we are reading the
@ -198,7 +395,6 @@ class TestDatasetOps(TestCase):
workspace.RunNet(str(read_next_net))
actual = FetchRecord(batch)
_assert_records_equal(actual, entry)
"""
5. Reading/writing in a single plan
@ -212,7 +408,6 @@ class TestDatasetOps(TestCase):
reset_net = core.Net('reset_net')
reader.reset(reset_net)
read_step, batch = reader.execution_step()
""" We will add the line number * 1000 to the feature ids. """
process_net = core.Net('process')
line_no = Const(process_net, 0, dtype=np.int32)
@ -221,7 +416,6 @@ class TestDatasetOps(TestCase):
field = batch.floats.keys.get()
process_net.Print(field, [])
process_net.Add([field, line_no], field, broadcast=1, axis=0)
""" Lets create a second dataset and append to it. """
ds2 = dataset.Dataset(schema, name='dataset2')
ds2.init_empty(reset_net)
@ -231,14 +425,12 @@ class TestDatasetOps(TestCase):
# generality of the example
commit_net = core.Net('commit')
writer.commit(commit_net)
""" Time to create and run a plan which will do the processing """
plan = core.Plan('process')
plan.AddStep(core.execution_step('reset', reset_net))
plan.AddStep(read_step.AddNet(process_net))
plan.AddStep(core.execution_step('commit', commit_net))
workspace.RunPlan(plan)
"""
Now we should have dataset2 populated.
"""
@ -246,7 +438,6 @@ class TestDatasetOps(TestCase):
field = ds2_data.floats.keys
field.set(blob=field.get() - [1000, 2000, 2000, 3000, 3000, 3000])
_assert_records_equal(contents, ds2_data)
"""
6. Slicing a dataset
@ -256,7 +447,6 @@ class TestDatasetOps(TestCase):
subschema = Struct(('top_level', schema.int_lists.values))
int_list_contents = contents.int_lists.values.field_names()
self.assertEquals(len(subschema.field_names()), len(int_list_contents))
"""
7. Random Access a dataset
@ -282,8 +472,6 @@ class TestDatasetOps(TestCase):
workspace.RunNet(str(read_next_net))
actual = FetchRecord(batch)
_assert_records_equal(actual, entry)
"""
8. Sort and shuffle a dataset
@ -328,31 +516,40 @@ class TestDatasetOps(TestCase):
num_to_collect=7,
)
plan = core.Plan('collect_data')
plan.AddStep(
core.execution_step('collect_data', [collect_net],
num_iter=1)
)
workspace.RunPlan(plan)
reference_result = workspace.FetchBlob('output')
self.assertSequenceEqual(
[item for sublist in reference_result for item in sublist],
[1, 2, 3, 4, 5, 6]
)
plan = core.Plan('collect_data')
plan.AddStep(
core.execution_step('collect_data', [collect_net],
num_iter=2)
)
workspace.RunPlan(plan)
reference_result = workspace.FetchBlob('output')
self.assertSequenceEqual(
[item for sublist in reference_result for item in sublist],
[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]
)
plan = core.Plan('collect_data')
plan.AddStep(
core.execution_step('collect_data', [collect_net],
num_iter=3)
)
workspace.RunPlan(plan)
reference_result = workspace.FetchBlob('output')
self.assertSequenceEqual(
[item for sublist in reference_result for item in sublist],
[3, 4, 5, 6, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2]
)
def test_collect_tensor_ops(self):
init_net = core.Net('init_net')
@ -382,9 +579,12 @@ class TestDatasetOps(TestCase):
plan = core.Plan('collect_data')
plan.AddStep(core.execution_step('collect_init', init_net))
plan.AddStep(
core.execution_step(
'collect_data', [reader_net, collect_net],
num_iter=max_example_to_cover
)
)
workspace.RunPlan(plan)
# concat the collected tensors
@ -401,14 +601,19 @@ class TestDatasetOps(TestCase):
# check data
reference_result = workspace.FetchBlob(bconcated_map[blobs[0]])
self.assertEqual(
reference_result.shape,
(min(num_to_collect, max_example_to_cover), 2)
)
size = workspace.FetchBlob(bsize_map[blobs[0]])
self.assertEqual(tuple(), size.shape)
self.assertEqual(min(num_to_collect, max_example_to_cover), size.item())
hist, _ = np.histogram(
reference_result[:, 0],
bins=10,
range=(1, max_example_to_cover)
)
print('Sample histogram: {}'.format(hist))
self.assertTrue(all(hist > 0.7 * (num_to_collect / 10)))
@ -416,6 +621,7 @@ class TestDatasetOps(TestCase):
result = workspace.FetchBlob(bconcated_map[blobs[i]])
self.assertEqual(reference_result.tolist(), result.tolist())
if __name__ == "__main__":
import unittest
unittest.main()