caffe2: datasets pack/unpack
Summary: Two new operators to pack and unpack a dataset, so that we can re-use other operators that do not understand the schema format. The immediate use-case is to combine them with a partition operator. Packing works by splitting the input into separate per-record tensors, putting them in a vector and wrapping that vector in a shared_ptr (as opposed to a unique_ptr, so it can be copied). Unpack takes the packed input and concatenates it back into the original tensors. I also had a hard time understanding the iteration code, so I created a TreeWalker that hides the complexity of operating on all the arrays and exposes short, single-purpose functions that are easier to understand.

Reviewed By: dzhulgakov

Differential Revision: D4918002

fbshipit-source-id: ecbf9196ed25e886a94383961176b8c84dde2d2f
Commit: 902409be56
Parent: 9cb901caf0
Committed by: Facebook Github Bot
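A minimal usage sketch distilled from the test_pack_unpack test added in this diff; the one-field schema and its contents are placeholders, and the comment marks where a row-wise operator such as a partitioner could route the packed rows without understanding the schema:

import numpy as np
from caffe2.python import core, workspace
from caffe2.python.schema import (
    Scalar, Struct, NewRecord, FeedRecord, from_blob_list)

# Placeholder schema and contents; any schema accepted by CreateTreeCursor works.
schema = Struct(('dense', Scalar(np.float32)))
contents = from_blob_list(schema, [np.array([1.0, 2.0, 3.0], dtype=np.float32)])
fields = schema.field_names()

net = core.Net('pack_unpack_sketch')
batch = NewRecord(net, contents)
FeedRecord(batch, contents)

# Pack: one input blob per field, one packed output blob (one element per row).
packed = net.PackRecords(batch.field_blobs(), 1, fields=fields)

# ... a row-wise operator (e.g. a partitioner) could split or reorder the
# packed rows here without knowing anything about the schema ...

# Unpack: the packed blob back into one output blob per field.
unpacked = packed.UnPackRecords([], len(fields), fields=fields)

workspace.RunNetOnce(net)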
@@ -274,6 +274,264 @@ class CheckDatasetConsistencyOp : public Operator<CPUContext> {
  TreeIterator iterator_;
};

/**
 * Simple wrapper class allowing an easy traversal of the tensors representing
 * the hierarchical structure.
 */
class TreeWalker {
 public:
  TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor)
      : inputs_(inputs), cursor_(cursor), sizes_(cursor.it.numOffsetFields()) {
    if (cursor.offsets.empty()) {
      cursor.offsets.assign(cursor.it.numOffsetFields(), 0);
    }

    for (int fieldId = 0; fieldId < cursor_.it.fields().size(); ++fieldId) {
      fields_.emplace_back(*this, fieldId);
    }

    gatherLengthData();

    gatherSizeLimits();

    // The invariant we hold is that we are always one step ahead
    advance();
  }

  // Returns the number of records in a dataset
  inline TOffset size() const {
    return limits_.at(0);
  }

  void advance() {
    prevOffsets_ = cursor_.offsets;
    cursor_.it.advance(lengths_, cursor_.offsets, sizes_, limits_, 1);
  }

 private:
  inline const TensorCPU& input(int32_t idx) const {
    return inputs_[idx]->Get<TensorCPU>();
  }

  // TODO: Change to fieldDesc
  inline const TreeIterator::FieldDesc& field(int idx) const {
    return cursor_.it.fields().at(idx);
  }

  inline int lengthIdx(int fieldId) const {
    return field(fieldId).lengthFieldId + 1;
  }

  inline TOffset offset(int fieldId) const {
    return prevOffsets_[lengthIdx(fieldId)];
  }

  std::vector<TIndex> fieldDim(int fieldId) const {
    auto tensorDim = input(fieldId).dims();
    tensorDim[0] = sizes_[lengthIdx(fieldId)];
    return tensorDim;
  }

  void* fieldPtr(int fieldId) const {
    auto& in = input(fieldId);
    return (char*)in.raw_data() +
        offset(fieldId) * in.size_from_dim(1) * in.meta().itemsize();
  }

 public:
  // Simple Proxy class to expose nicer API for field access
  class Field {
   public:
    Field(TreeWalker& walker, int fieldId)
        : walker_(walker), fieldId_(fieldId) {}

    inline std::vector<TIndex> dim() const {
      return walker_.fieldDim(fieldId_);
    }

    inline const TypeMeta& meta() const {
      return walker_.input(fieldId_).meta();
    }

    inline void* ptr() const {
      return walker_.fieldPtr(fieldId_);
    }

   private:
    const TreeWalker& walker_;
    const int fieldId_;
  };

  // Notice that a reference is returned. If advance() is called the fields will
  // be updated to represent the new state.
  inline const std::vector<Field>& fields() const {
    return fields_;
  }

 private:
  void gatherLengthData() {
    static const TLength lenZero = 0;
    lengths_.resize(cursor_.it.numLengthFields());
    for (int i = 0; i < lengths_.size(); ++i) {
      auto& in = input(cursor_.it.lengthField(i).id);
      if (in.size() > 0) {
        lengths_[i] = in.data<int>();
      } else {
        lengths_[i] = &lenZero;
      }
    }
  }

  void gatherSizeLimits() {
    limits_.assign(sizes_.size(), std::numeric_limits<TOffset>::max());
    for (auto fieldId = 0; fieldId < cursor_.it.fields().size(); ++fieldId) {
      auto lengthFieldIdx = lengthIdx(fieldId);
      limits_[lengthFieldIdx] =
          std::min(limits_[lengthFieldIdx], (TOffset)input(fieldId).dims()[0]);
    }
  }

  const vector<const Blob*>& inputs_;
  TreeCursor& cursor_;
  std::vector<Field> fields_;

  std::vector<const TLength*> lengths_;
  std::vector<TOffset> limits_;
  std::vector<TOffset> sizes_;
  std::vector<TOffset> offsets_;
  std::vector<TOffset> prevOffsets_;
};

using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;

class PackRecordsOp : public Operator<CPUContext> {
 public:
  PackRecordsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws),
        fields_(OperatorBase::GetRepeatedArgument<std::string>("fields")) {}

  bool RunOnDevice() override {
    // There should be one input per field
    CAFFE_ENFORCE_EQ(InputSize(), fields_.size());
    CAFFE_ENFORCE_EQ(OutputSize(), 1);

    TreeCursor cursor((TreeIterator(fields_)));

    TreeWalker walker(Inputs(), cursor);

    Output(0)->Resize(walker.size());

    // Output(0)->raw_mutable_data(TypeMeta::Make<SharedTensorVectorPtr>()));
    auto* dst = Output(0)->mutable_data<SharedTensorVectorPtr>();

    for (int batchId = 0; batchId < walker.size(); ++batchId) {
      dst[batchId] = std::make_shared<std::vector<TensorCPU>>();
      dst[batchId]->reserve(walker.fields().size());

      for (const auto& field : walker.fields()) {
        dst[batchId]->emplace_back(field.dim());
        auto& tensor = dst[batchId]->back();
        context_.template CopyItems<CPUContext, CPUContext>(
            field.meta(),
            tensor.size(),
            field.ptr() /* src */,
            tensor.raw_mutable_data(field.meta()) /* dst */);
      }

      walker.advance();
    }

    return true;
  }

 private:
  std::vector<std::string> fields_;
};

class UnPackRecordsOp : public Operator<CPUContext> {
 public:
  UnPackRecordsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws),
        fields_(OperatorBase::GetRepeatedArgument<std::string>("fields")) {}

  bool RunOnDevice() override {
    const auto* inputs = Input(0).template data<SharedTensorVectorPtr>();
    const auto numRows = Input(0).size();

    CAFFE_ENFORCE_GE(numRows, 0);

    if (numRows == 0) {
      return true;
    }

    const auto& inputZero = inputs[0];
    CAFFE_ENFORCE(inputZero);

    const auto numTensors = inputZero->size();

    CAFFE_ENFORCE_EQ(numTensors, fields_.size());
    CAFFE_ENFORCE_EQ(numTensors, OutputSize());

    // Precompute the output sizes to avoid resizing
    std::vector<std::vector<TIndex>> outputDims(numTensors);

    for (int i = 0; i < numTensors; ++i) {
      outputDims[i] = inputs[0]->at(i).dims();
      outputDims[i][0] = 0;
    }

    for (int i = 0; i < numRows; ++i) {
      CAFFE_ENFORCE(inputs[i]);
      for (int j = 0; j < inputs[i]->size(); ++j) {
        const auto& input = inputs[i]->at(j);
        const auto& inputZeroTensor = inputZero->at(j);

        // Checks to ensure that dimensions/sizes match
        CAFFE_ENFORCE_EQ(inputZeroTensor.ndim(), input.ndim());
        CAFFE_ENFORCE(inputZeroTensor.meta() == input.meta());
        // Skip the first dimension, because we concatenate along it.
        for (int k = 1; k < input.ndim(); ++k) {
          CAFFE_ENFORCE_EQ(input.dims()[k], inputZeroTensor.dims()[k]);
        }

        outputDims[j][0] += input.dim(0);
      }
    }

    // Resize to the final output size
    std::vector<void*> destinations(numTensors);
    for (int i = 0; i < numTensors; ++i) {
      Output(i)->Resize(outputDims[i]);
      destinations[i] = Output(i)->raw_mutable_data(inputZero->at(i).meta());
    }

    for (int i = 0; i < numRows; ++i) {
      for (int j = 0; j < numTensors; ++j) {
        const auto& input = inputs[i]->at(j);
        // Skip empty tensors
        if (input.size() == 0) {
          continue;
        }

        context_.CopyItems<CPUContext, CPUContext>(
            inputZero->at(j).meta(),
            input.size(),
            input.raw_data() /* src */,
            destinations[j] /* dst */
        );

        destinations[j] =
            (char*)destinations[j] + input.size() * inputZero->at(j).itemsize();
      }
    }

    return true;
  }

 private:
  std::vector<std::string> fields_;
};

class ReadNextBatchOp : public Operator<CPUContext> {
 public:
  ReadNextBatchOp(const OperatorDef& operator_def, Workspace* ws)
@@ -803,6 +1061,8 @@ REGISTER_CPU_OPERATOR(CreateTensorVector, CreateTensorVectorOp<CPUContext>);
REGISTER_CPU_OPERATOR(TensorVectorSize, TensorVectorSizeOp<CPUContext>);
REGISTER_CPU_OPERATOR(ConcatTensorVector, ConcatTensorVectorOp<CPUContext>);
REGISTER_CPU_OPERATOR(CollectTensor, CollectTensorOp<CPUContext>);
REGISTER_CPU_OPERATOR(PackRecords, PackRecordsOp);
REGISTER_CPU_OPERATOR(UnPackRecords, UnPackRecordsOp);

OPERATOR_SCHEMA(CreateTreeCursor)
    .NumInputs(0)
@@ -1048,6 +1308,34 @@ output vectors.
)DOC")
    .Arg("num_to_collect", "The max number of tensors to collect");

OPERATOR_SCHEMA(PackRecords)
    .NumInputs(1, INT_MAX)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Given a dataset under a schema specified by the `fields` argument, pack all the input tensors into one tensor,
where each element represents a row of data (a batch of size 1). This format allows easier use with the rest of the Caffe2 operators.
)DOC")
    .Arg(
        "fields",
        "List of strings representing the string names in the format "
        "specified in the doc for CreateTreeCursor.")
    .Output(
        0,
        "tensor",
        "One dimensional tensor having a complex type of SharedTensorVectorPtr. In order to reverse it back to the original input it has to be inserted into UnPackRecordsOp.");

OPERATOR_SCHEMA(UnPackRecords)
    .NumInputs(1)
    .NumOutputs(1, INT_MAX)
    .SetDoc(R"DOC(
Given a packed dataset (packed by the PackRecordsOp) and the `fields` argument describing the dataset's schema, return the original dataset format.
The number of returned tensors is equal to the number of fields in the `fields` argument.
)DOC")
    .Arg(
        "fields",
        "List of strings representing the string names in the format "
        "specified in the doc for CreateTreeCursor.");

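For context on how the two schemas above pair up, a hedged round-trip sketch that mirrors the new test_pack_unpack test: the number of inputs to PackRecords, the length of `fields`, and the number of outputs of UnPackRecords all agree, and unpacking restores the original field blobs (the two-field schema below is a stand-in):

import numpy as np
import numpy.testing as npt
from caffe2.python import core, workspace
from caffe2.python.schema import (
    Scalar, Struct, NewRecord, FeedRecord, from_blob_list)

# Stand-in schema: two scalar fields, three records each.
schema = Struct(
    ('ids', Scalar(np.int64)),
    ('scores', Scalar(np.float32)),
)
record = from_blob_list(schema, [
    np.array([1, 2, 3], dtype=np.int64),
    np.array([0.1, 0.2, 0.3], dtype=np.float32),
])
fields = schema.field_names()

net = core.Net('pack_unpack_check')
batch = NewRecord(net, record)
FeedRecord(batch, record)

packed = net.PackRecords(batch.field_blobs(), 1, fields=fields)
unpacked = packed.UnPackRecords([], len(fields), fields=fields)
workspace.RunNetOnce(net)

# One unpacked output per field, equal to the corresponding original blob.
for original, restored in zip(batch.field_blobs(), unpacked):
    npt.assert_array_equal(
        workspace.FetchBlob(original), workspace.FetchBlob(restored))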
SHOULD_NOT_DO_GRADIENT(CreateTreeCursor);
SHOULD_NOT_DO_GRADIENT(ResetCursor);
SHOULD_NOT_DO_GRADIENT(ReadNextBatch);
@@ -1060,9 +1348,12 @@ SHOULD_NOT_DO_GRADIENT(CreateTensorVector);
SHOULD_NOT_DO_GRADIENT(TensorVectorSize);
SHOULD_NOT_DO_GRADIENT(ConcatTensorVector);
SHOULD_NOT_DO_GRADIENT(CollectTensor);
SHOULD_NOT_DO_GRADIENT(UnPack);
SHOULD_NOT_DO_GRADIENT(Pack);
} // namespace
CAFFE_KNOWN_TYPE(std::unique_ptr<TreeCursor>);
CAFFE_KNOWN_TYPE(TensorVectorPtr<CPUContext>);
CAFFE_KNOWN_TYPE(SharedTensorVectorPtr);

namespace {

@@ -6,17 +6,27 @@ import numpy as np
from caffe2.python import core, workspace, dataset
from caffe2.python.dataset import Const
from caffe2.python.schema import (
    List, Field, Struct, Scalar, Map,
    from_blob_list, FetchRecord, NewRecord, FeedRecord)
    List, Field, Struct, Scalar, Map, from_blob_list, FetchRecord, NewRecord,
    FeedRecord
)
from caffe2.python.test_util import TestCase

import numpy.testing as npt

import string

from hypothesis import given
import hypothesis.strategies as st


def _assert_arrays_equal(actual, ref, err_msg):
    if ref.dtype.kind in ('S', 'O'):
        np.testing.assert_array_equal(actual, ref, err_msg=err_msg)
    else:
        np.testing.assert_allclose(
            actual, ref, atol=1e-4, rtol=1e-4, err_msg=err_msg)
            actual, ref, atol=1e-4,
            rtol=1e-4, err_msg=err_msg
        )


def _assert_records_equal(actual, ref):
@@ -24,13 +34,169 @@ def _assert_records_equal(actual, ref):
    assert isinstance(ref, Field)
    b1 = actual.field_blobs()
    b2 = ref.field_blobs()
    assert(len(b1) == len(b2)), 'Records have different lengths: %d vs. %d' % (
        len(b1), len(b2))
    assert (len(b1) == len(b2)), 'Records have different lengths: %d vs. %d' % (
        len(b1), len(b2)
    )
    for name, d1, d2 in zip(ref.field_names(), b1, b2):
        _assert_arrays_equal(d1, d2, err_msg='Mismatch in field %s.' % name)


@st.composite
def _sparse_features_map(draw, num_records, **kwargs):
    sparse_maps_lengths = draw(
        st.lists(
            st.integers(min_value=1, max_value=10),
            min_size=num_records,
            max_size=num_records
        )
    )

    sparse_maps_total_length = sum(sparse_maps_lengths)

    sparse_keys = draw(
        st.lists(
            st.integers(min_value=1, max_value=100),
            min_size=sparse_maps_total_length,
            max_size=sparse_maps_total_length,
            unique=True
        )
    )

    sparse_values_lengths = draw(
        st.lists(
            st.integers(min_value=1, max_value=10),
            min_size=sparse_maps_total_length,
            max_size=sparse_maps_total_length
        )
    )

    total_sparse_values_lengths = sum(sparse_values_lengths)

    sparse_values = draw(
        # max_value is max int64
        st.lists(
            st.integers(min_value=1, max_value=9223372036854775807),
            min_size=total_sparse_values_lengths,
            max_size=total_sparse_values_lengths
        )
    )

    return [
        sparse_maps_lengths,
        sparse_keys,
        sparse_values_lengths,
        sparse_values,
    ]


@st.composite
def _dense_features_map(draw, num_records, **kwargs):
    float_lengths = draw(
        st.lists(
            st.integers(min_value=1, max_value=10),
            min_size=num_records,
            max_size=num_records
        )
    )

    total_length = sum(float_lengths)

    float_keys = draw(
        st.lists(
            st.integers(min_value=1, max_value=100),
            min_size=total_length,
            max_size=total_length,
            unique=True
        )
    )

    float_values = draw(
        st.lists(st.floats(),
                 min_size=total_length,
                 max_size=total_length)
    )

    return [float_lengths, float_keys, float_values]


@st.composite
def _dataset(draw, min_elements=3, max_elements=10, **kwargs):
    schema = Struct(
        # Dense Features Map
        ('floats', Map(
            Scalar(np.int32), Scalar(np.float32)
        )),
        # Sparse Features Map
        ('int_lists', Map(
            Scalar(np.int32),
            List(Scalar(np.int64)),
        )),
        # Complex Type
        ('text', Scalar(str)),
    )

    num_records = draw(
        st.integers(min_value=min_elements,
                    max_value=max_elements)
    )

    raw_dense_features_map_contents = draw(_dense_features_map(num_records))

    raw_sparse_features_map_contents = draw(_sparse_features_map(num_records))

    raw_text_contents = [
        draw(
            st.lists(
                st.text(alphabet=string.ascii_lowercase),
                min_size=num_records,
                max_size=num_records
            )
        )
    ]

    # Concatenate all raw contents to a single one
    contents_raw = raw_dense_features_map_contents + raw_sparse_features_map_contents + raw_text_contents

    contents = from_blob_list(schema, contents_raw)

    return (schema, contents, num_records)


class TestDatasetOps(TestCase):
    @given(_dataset())
    def test_pack_unpack(self, input):
        """
        Tests if packing and unpacking of the whole dataset is an identity.
        """
        (schema, contents, num_records) = input

        dataset_fields = schema.field_names()

        net = core.Net('pack_unpack_net')

        batch = NewRecord(net, contents)
        FeedRecord(batch, contents)

        packed = net.PackRecords(
            batch.field_blobs(), 1,
            fields=dataset_fields
        )

        unpacked = packed.UnPackRecords(
            [], len(dataset_fields),
            fields=dataset_fields
        )

        workspace.RunNetOnce(net)

        for initial_tensor, unpacked_tensor in zip(
            batch.field_blobs(), unpacked
        ):
            npt.assert_array_equal(
                workspace.FetchBlob(initial_tensor),
                workspace.FetchBlob(unpacked_tensor)
            )

    def test_dataset_ops(self):
        """
        1. Defining the schema of our dataset.
@@ -42,30 +208,34 @@ class TestDatasetOps(TestCase):
            ('dense', Scalar((np.float32, 3))),
            # could represent a feature map from feature ID to float value
            ('floats', Map(
                Scalar(np.int32),
                Scalar(np.float32))),
                Scalar(np.int32), Scalar(np.float32)
            )),
            # could represent a multi-valued categorical feature map
            ('int_lists', Map(
                Scalar(np.int32),
                List(Scalar(np.int64)),
            )),
            # could represent a multi-valued, weighted categorical feature map
            ('id_score_pairs', Map(
                Scalar(np.int32),
                Map(
                    Scalar(np.int64),
                    Scalar(np.float32),
                    keys_name='ids',
                    values_name='scores'),
            )),
            (
                'id_score_pairs', Map(
                    Scalar(np.int32),
                    Map(
                        Scalar(np.int64),
                        Scalar(np.float32),
                        keys_name='ids',
                        values_name='scores'
                    ),
                )
            ),
            # additional scalar information
            ('metadata', Struct(
                ('user_id', Scalar(np.int64)),
                ('user_embed', Scalar((np.float32, 2))),
                ('query', Scalar(str)),
            )),
            (
                'metadata', Struct(
                    ('user_id', Scalar(np.int64)),
                    ('user_embed', Scalar((np.float32, 2))),
                    ('query', Scalar(str)),
                )
            ),
        )

"""
|
||||
This is what the flattened fields for this schema look like, along
|
||||
with its type. Each one of these fields will be stored, read and
|
||||
@ -90,13 +260,11 @@ class TestDatasetOps(TestCase):
|
||||
('metadata:query', str),
|
||||
]
|
||||
zipped = zip(
|
||||
expected_fields,
|
||||
schema.field_names(),
|
||||
schema.field_types())
|
||||
expected_fields, schema.field_names(), schema.field_types()
|
||||
)
|
||||
for (ref_name, ref_type), name, dtype in zipped:
|
||||
self.assertEquals(ref_name, name)
|
||||
self.assertEquals(np.dtype(ref_type), dtype)
|
||||
|
||||
"""
|
||||
2. The contents of our dataset.
|
||||
|
||||
@ -129,7 +297,6 @@ class TestDatasetOps(TestCase):
|
||||
]
|
||||
# convert the above content to ndarrays, checking against the schema
|
||||
contents = from_blob_list(schema, contents_raw)
|
||||
|
||||
"""
|
||||
3. Creating and appending to the dataset.
|
||||
We first create an empty dataset with the given schema.
|
||||
@ -145,7 +312,6 @@ class TestDatasetOps(TestCase):
|
||||
writer = ds.writer(init_net=net)
|
||||
writer.write_record(net, content_blobs)
|
||||
workspace.RunNetOnce(net)
|
||||
|
||||
"""
|
||||
4. Iterating through the dataset contents.
|
||||
|
||||
@ -155,32 +321,63 @@ class TestDatasetOps(TestCase):
|
||||
        entries_raw = [
            (
                [[1.1, 1.2, 1.3]],  # dense
                [1], [11], [1.1],  # floats
                [2], [11, 12], [2, 4], [111, 112, 121, 122, 123, 124],  # intlst
                [1], [11], [1], [111], [11.1],  # id score pairs
                [123], [[0.2, 0.8]], ['dog posts'],  # metadata
                [1],
                [11],
                [1.1],  # floats
                [2],
                [11, 12],
                [2, 4],
                [111, 112, 121, 122, 123, 124],  # intlst
                [1],
                [11],
                [1],
                [111],
                [11.1],  # id score pairs
                [123],
                [[0.2, 0.8]],
                ['dog posts'],  # metadata
            ),
            (
                [[2.1, 2.2, 2.3]],  # dense
                [2], [21, 22], [2.1, 2.2],  # floats
                [0], [], [], [],  # int list
                [2], [21, 22], [1, 2], [211, 221, 222], [21.1, 22.1, 22.2],
                [234], [[0.5, 0.5]], ['friends who like to'],  # metadata
                [2],
                [21, 22],
                [2.1, 2.2],  # floats
                [0],
                [],
                [],
                [],  # int list
                [2],
                [21, 22],
                [1, 2],
                [211, 221, 222],
                [21.1, 22.1, 22.2],
                [234],
                [[0.5, 0.5]],
                ['friends who like to'],  # metadata
            ),
            (
                [[3.1, 3.2, 3.3]],  # dense
                [3], [31, 32, 33], [3.1, 3.2, 3.3],  # floats
                [1], [31], [3], [311, 312, 313],  # int lst
                [2], [31, 32], [2, 3], [311, 312, 321, 322, 323],
                [3],
                [31, 32, 33],
                [3.1, 3.2, 3.3],  # floats
                [1],
                [31],
                [3],
                [311, 312, 313],  # int lst
                [2],
                [31, 32],
                [2, 3],
                [311, 312, 321, 322, 323],
                [31.1, 31.2, 32.1, 32.2, 32.3],  # id score list
                [456], [[0.7, 0.3]], ['posts about ca'],  # metadata
                [456],
                [[0.7, 0.3]],
                ['posts about ca'],  # metadata
            ),
            # after the end of the dataset, we will keep getting empty vectors
            ([],) * 16,
            ([],) * 16,
            ([], ) * 16,
            ([], ) * 16,
        ]
        entries = [from_blob_list(schema, e) for e in entries_raw]

"""
|
||||
Let's go ahead and create the reading nets.
|
||||
We will run `read` net multiple times and assert that we are reading the
|
||||
@ -198,7 +395,6 @@ class TestDatasetOps(TestCase):
|
||||
workspace.RunNet(str(read_next_net))
|
||||
actual = FetchRecord(batch)
|
||||
_assert_records_equal(actual, entry)
|
||||
|
||||
"""
|
||||
5. Reading/writing in a single plan
|
||||
|
||||
@ -212,7 +408,6 @@ class TestDatasetOps(TestCase):
|
||||
reset_net = core.Net('reset_net')
|
||||
reader.reset(reset_net)
|
||||
read_step, batch = reader.execution_step()
|
||||
|
||||
""" We will add the line number * 1000 to the feature ids. """
|
||||
process_net = core.Net('process')
|
||||
line_no = Const(process_net, 0, dtype=np.int32)
|
||||
@ -221,7 +416,6 @@ class TestDatasetOps(TestCase):
|
||||
field = batch.floats.keys.get()
|
||||
process_net.Print(field, [])
|
||||
process_net.Add([field, line_no], field, broadcast=1, axis=0)
|
||||
|
||||
""" Lets create a second dataset and append to it. """
|
||||
ds2 = dataset.Dataset(schema, name='dataset2')
|
||||
ds2.init_empty(reset_net)
|
||||
@ -231,14 +425,12 @@ class TestDatasetOps(TestCase):
|
||||
# generality of the example
|
||||
commit_net = core.Net('commit')
|
||||
writer.commit(commit_net)
|
||||
|
||||
""" Time to create and run a plan which will do the processing """
|
||||
plan = core.Plan('process')
|
||||
plan.AddStep(core.execution_step('reset', reset_net))
|
||||
plan.AddStep(read_step.AddNet(process_net))
|
||||
plan.AddStep(core.execution_step('commit', commit_net))
|
||||
workspace.RunPlan(plan)
|
||||
|
||||
"""
|
||||
Now we should have dataset2 populated.
|
||||
"""
|
||||
@ -246,7 +438,6 @@ class TestDatasetOps(TestCase):
|
||||
field = ds2_data.floats.keys
|
||||
field.set(blob=field.get() - [1000, 2000, 2000, 3000, 3000, 3000])
|
||||
_assert_records_equal(contents, ds2_data)
|
||||
|
||||
"""
|
||||
6. Slicing a dataset
|
||||
|
||||
@ -256,7 +447,6 @@ class TestDatasetOps(TestCase):
|
||||
subschema = Struct(('top_level', schema.int_lists.values))
|
||||
int_list_contents = contents.int_lists.values.field_names()
|
||||
self.assertEquals(len(subschema.field_names()), len(int_list_contents))
|
||||
|
||||
"""
|
||||
7. Random Access a dataset
|
||||
|
||||
@ -282,8 +472,6 @@ class TestDatasetOps(TestCase):
|
||||
workspace.RunNet(str(read_next_net))
|
||||
actual = FetchRecord(batch)
|
||||
_assert_records_equal(actual, entry)
|
||||
|
||||
|
||||
"""
|
||||
8. Sort and shuffle a dataset
|
||||
|
||||
@ -328,31 +516,40 @@ class TestDatasetOps(TestCase):
|
||||
            num_to_collect=7,
        )
        plan = core.Plan('collect_data')
        plan.AddStep(core.execution_step('collect_data',
                                         [collect_net], num_iter=1))
        plan.AddStep(
            core.execution_step('collect_data', [collect_net],
                                num_iter=1)
        )
        workspace.RunPlan(plan)
        reference_result = workspace.FetchBlob('output')
        self.assertSequenceEqual(
            [item for sublist in reference_result for item in sublist],
            [1, 2, 3, 4, 5, 6])
            [1, 2, 3, 4, 5, 6]
        )

        plan = core.Plan('collect_data')
        plan.AddStep(core.execution_step('collect_data',
                                         [collect_net], num_iter=2))
        plan.AddStep(
            core.execution_step('collect_data', [collect_net],
                                num_iter=2)
        )
        workspace.RunPlan(plan)
        reference_result = workspace.FetchBlob('output')
        self.assertSequenceEqual(
            [item for sublist in reference_result for item in sublist],
            [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6])
            [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]
        )

        plan = core.Plan('collect_data')
        plan.AddStep(core.execution_step('collect_data',
                                         [collect_net], num_iter=3))
        plan.AddStep(
            core.execution_step('collect_data', [collect_net],
                                num_iter=3)
        )
        workspace.RunPlan(plan)
        reference_result = workspace.FetchBlob('output')
        self.assertSequenceEqual(
            [item for sublist in reference_result for item in sublist],
            [3, 4, 5, 6, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2])
            [3, 4, 5, 6, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2]
        )

    def test_collect_tensor_ops(self):
        init_net = core.Net('init_net')
@@ -382,9 +579,12 @@ class TestDatasetOps(TestCase):

        plan = core.Plan('collect_data')
        plan.AddStep(core.execution_step('collect_init', init_net))
        plan.AddStep(core.execution_step('collect_data',
                                         [reader_net, collect_net],
                                         num_iter=max_example_to_cover))
        plan.AddStep(
            core.execution_step(
                'collect_data', [reader_net, collect_net],
                num_iter=max_example_to_cover
            )
        )
        workspace.RunPlan(plan)

        # concat the collected tensors
@@ -401,14 +601,19 @@ class TestDatasetOps(TestCase):

        # check data
        reference_result = workspace.FetchBlob(bconcated_map[blobs[0]])
        self.assertEqual(reference_result.shape,
                         (min(num_to_collect, max_example_to_cover), 2))
        self.assertEqual(
            reference_result.shape,
            (min(num_to_collect, max_example_to_cover), 2)
        )
        size = workspace.FetchBlob(bsize_map[blobs[0]])
        self.assertEqual(tuple(), size.shape)
        self.assertEqual(min(num_to_collect, max_example_to_cover), size.item())

        hist, _ = np.histogram(reference_result[:, 0], bins=10,
                               range=(1, max_example_to_cover))
        hist, _ = np.histogram(
            reference_result[:, 0],
            bins=10,
            range=(1, max_example_to_cover)
        )
        print('Sample histogram: {}'.format(hist))

        self.assertTrue(all(hist > 0.7 * (num_to_collect / 10)))
@@ -416,6 +621,7 @@ class TestDatasetOps(TestCase):
            result = workspace.FetchBlob(bconcated_map[blobs[i]])
            self.assertEqual(reference_result.tolist(), result.tolist())


if __name__ == "__main__":
    import unittest
    unittest.main()