Support stateful dataset (#15096)

Summary:
Currently re-implements the dataloader for stateful datasets. Outstanding work:
- Refactor DataLoader and DataLoader2 to have common base classes and only differ in specific pieces of logic,
- Figure out how to not duplicate the `MapDataset` logic for stateful vs. non-stateful
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15096

Differential Revision: D13522043

Pulled By: goldsborough

fbshipit-source-id: 08e461ca51783047f11facc4d27dfa2e4f1e4c2a
This commit is contained in:
Peter Goldsborough
2018-12-24 06:23:32 -08:00
committed by Facebook Github Bot
parent 8cd917812b
commit ad6799537e
12 changed files with 720 additions and 319 deletions

View File

@ -84,8 +84,8 @@ TEST(DataTest, InfiniteStreamDataset) {
auto data_loader = torch::data::make_data_loader(
std::move(dataset),
kBatchSize,
samplers::StreamSampler(/*epoch_size=*/39));
samplers::StreamSampler(/*epoch_size=*/39),
kBatchSize);
size_t batch_index = 0;
for (auto& batch : *data_loader) {
@ -128,10 +128,10 @@ TEST(DataTest, OrderedSequencerReOrdersValues) {
size_t index = 0;
auto getter = [&v, &index]() { return S{v.at(index++)}; };
// Let's say the sequence number matches for the first one, then it should
// Let's say the sequence number matches for the batch one, then it should
// return immediately.
const auto first = sequencer.next(getter);
ASSERT_EQ(first.value().sequence_number, 0);
const auto batch = sequencer.next(getter);
ASSERT_EQ(batch.value().sequence_number, 0);
ASSERT_EQ(index, 1);
// Now it should call the getter until it gets the next value.
@ -385,9 +385,9 @@ TEST(DataTest, StackTransformWorksForExample) {
auto d = D().map(transforms::Stack<Example<>>());
Example<> first = d.get_batch({0, 1});
ASSERT_TRUE(first.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
ASSERT_TRUE(first.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));
Example<> batch = d.get_batch({0, 1});
ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));
Example<> second = d.get_batch({2, 3});
ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
@ -398,8 +398,8 @@ TEST(DataTest, StackTransformWorksForTensorExample) {
auto d = datasets::TensorDataset(torch::eye(4))
.map(transforms::Stack<TensorExample>());
TensorExample first = d.get_batch({0, 1});
ASSERT_TRUE(first.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
TensorExample batch = d.get_batch({0, 1});
ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
TensorExample second = d.get_batch({2, 3});
ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
@ -504,7 +504,7 @@ TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
using torch::data::detail::Queue;
// First test: push first and the pop in thread.
// First test: push batch and the pop in thread.
{
Queue<int> queue;
queue.push(1);
@ -513,7 +513,7 @@ TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
ASSERT_EQ(future.get(), 1);
}
// Second test: attempt to pop first (and block), then push.
// Second test: attempt to pop batch (and block), then push.
{
Queue<int> queue;
std::thread thread([&queue] {
@ -544,7 +544,7 @@ TEST(DataTest, DataShuttleCanPushAndPopJob) {
TEST(DataTest, DataShuttleCanPushAndPopResult) {
torch::data::detail::DataShuttle<int, int> shuttle;
// pop_result() will only attempt to pop if there was a push_job() first.
// pop_result() will only attempt to pop if there was a push_job() batch.
shuttle.push_job(1);
shuttle.push_job(2);
@ -672,9 +672,9 @@ struct TestIndexSampler : public samplers::Sampler<TestIndex> {
};
TEST(DataTest, CanUseCustomTypeAsIndexType) {
const size_t kBatchSize = 10;
const int kBatchSize = 10;
auto data_loader = torch::data::make_data_loader(
TestIndexDataset(23), kBatchSize, TestIndexSampler(23));
TestIndexDataset(23), TestIndexSampler(23), kBatchSize);
size_t i = 0;
for (auto batch : *data_loader) {
@ -948,7 +948,7 @@ TEST(DataLoaderTest, RespectsTimeout) {
ASSERT_LT(duration.count(), 1);
}
// https://stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
// stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
struct Barrier {
explicit Barrier(size_t target) : counter_(target) {}
void wait() {
@ -973,12 +973,12 @@ struct Barrier {
// thread (for outside consumption) is not deterministic. Imagine the sampler is
// a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index
// will be a single "job". Inside the dataloader, worker threads block until a
// job is available. It is not deterministic which worker thread wakes up first
// job is available. It is not deterministic which worker thread wakes up batch
// to dequeue a particular batch. Further, some worker threads may take longer
// than others to read the data for their index. As such, it could be that
// worker thread 2 finishes before all other threads and returns its batch to
// the main thread. In that case, the dataloader iterator would return the datum
// at index 2 first, and afterwards the datum from whatever thread finishes
// at index 2 batch, and afterwards the datum from whatever thread finishes
// next. As such, the user may see data from indices 2, 0, 3, 1. On another run
// of the same dataloader on the same data, threads may be scheduled differently
// and return in order 0, 2, 3, 1. To force this ordering to deterministically
@ -996,7 +996,7 @@ struct Barrier {
// `SequentialSampler` in the range `0...kNumberOfWorkers-1`. Each worker thread
// has a copy of the dataset, and thus `get_batch()` is called on the
// thread-local copy in each worker. We want to simulate out-of-order completion
// of these threads. For this, we first set a barrier in the `get_batch()`
// of these threads. For this, we batch set a barrier in the `get_batch()`
// method to make sure every worker has some index to fetch assigned. Further,
// each worker thread has a unique ID in `0...kNumberOfWorkers-1`.
// There is a hard-coded ordering, `kOrderInWhichWorkersReturnTheirBatch`, in
@ -1057,12 +1057,11 @@ struct Dataset : datasets::BatchDataset<Dataset, size_t> {
TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
auto data_loader = torch::data::make_data_loader(
ordering_test::Dataset{},
torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers),
DataLoaderOptions()
.batch_size(1)
.workers(ordering_test::kNumberOfWorkers)
.enforce_ordering(true),
torch::data::samplers::SequentialSampler(
ordering_test::kNumberOfWorkers));
.enforce_ordering(true));
std::vector<size_t> output;
for (size_t value : *data_loader) {
output.push_back(value);
@ -1104,8 +1103,8 @@ TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
}
};
auto data_loader =
torch::data::make_data_loader(D{}, DataLoaderOptions().workers(2));
auto data_loader = torch::data::make_data_loader(
D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
auto iterator = data_loader->begin();
try {
@ -1119,3 +1118,159 @@ TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
std::rethrow_exception(e.original_exception), std::invalid_argument);
}
}
// Builds a data loader from a stateful dataset alone (no options given) and
// expects every one of ten passes over it to yield exactly as many batches as
// the dataset hands out before reporting exhaustion.
TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Hands out 0, 1, 2, ... and returns an empty optional once the limit above
  // is reached. `reset()` rewinds the counter to zero.
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter >= kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return torch::nullopt;
      }
      return counter++;
    }

    // Note: the reported size differs from the exhaustion limit above.
    torch::optional<size_t> size() const override {
      return 100;
    }

    void reset() override {
      counter = 0;
    }

    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(D{});

  for (size_t epoch = 0; epoch < 10; ++epoch) {
    const auto batches_seen =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(batches_seen, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << epoch;
  }

  // Every value produced must lie below the exhaustion limit.
  for (const int value : *data_loader) {
    ASSERT_LT(value, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
// Same scenario as a single-threaded stateful dataset, but the loader is
// configured with several workers and the dataset is wrapped via
// `make_shared_dataset` (presumably so all workers see one counter -- confirm
// against the shared-dataset implementation).
TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  const int kNumberOfWorkers = 4;

  // Hands out 0, 1, 2, ... until the limit is reached, then an empty optional.
  // `get_batch()` holds `mutex` while it reads and bumps `counter`.
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      std::lock_guard<std::mutex> guard(mutex);
      if (counter >= kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return torch::nullopt;
      }
      return counter++;
    }

    // Note: the reported size differs from the exhaustion limit above.
    torch::optional<size_t> size() const override {
      return 100;
    }

    // NOTE(review): unlike `get_batch()`, this does not take `mutex`; it
    // presumably relies on no worker running during a reset -- confirm.
    void reset() override {
      counter = 0;
    }

    int counter = 0;
    std::mutex mutex;
  };

  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::make_shared_dataset<D>(),
      DataLoaderOptions().workers(kNumberOfWorkers));

  for (size_t epoch = 0; epoch < 10; ++epoch) {
    const auto batches_seen =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(batches_seen, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << epoch;
  }

  // Every value produced must lie below the exhaustion limit.
  for (const int value : *data_loader) {
    ASSERT_LT(value, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
// Chains two `BatchLambda` transforms (int -> string -> tensor) onto a stateful
// dataset and expects ten passes of exactly
// `kNumberOfExamplesAfterWhichTheDatasetExhausts` batches each, with every
// resulting tensor value below that limit.
TEST(DataLoaderTest, StatefulDatasetWithMap) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Hands out 0, 1, 2, ... until the limit is reached, then an empty optional.
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter >= kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return torch::nullopt;
      }
      return counter++;
    }

    // Note: the reported size differs from the exhaustion limit above.
    torch::optional<size_t> size() const override {
      return 100;
    }

    void reset() override {
      counter = 0;
    }

    int counter = 0;
  };

  // First stage: render the integer as text.
  auto int_to_string = [](int number) { return std::to_string(number); };
  // Second stage: parse the text back and wrap it in a scalar tensor.
  auto string_to_tensor = [](const std::string& text) {
    return torch::tensor(static_cast<int64_t>(std::stoi(text)));
  };

  auto mapped_dataset =
      D().map(transforms::BatchLambda<int, std::string>(int_to_string))
          .map(transforms::BatchLambda<std::string, torch::Tensor>(
              string_to_tensor));

  auto data_loader = torch::data::make_data_loader(
      std::move(mapped_dataset), DataLoaderOptions{});

  for (size_t epoch = 0; epoch < 10; ++epoch) {
    const auto batches_seen =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(batches_seen, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << epoch;
  }

  for (const torch::Tensor& tensor : *data_loader) {
    ASSERT_LT(
        tensor.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
// Applies the `Stack` collation to a stateful dataset whose `get_batch()`
// returns a vector of examples, and checks the shape and contents of the
// single stacked example that comes back.
TEST(DataLoaderTest, StatefulDatasetWithCollate) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Each call advances `counter` by the requested batch size and returns
  // `batch_size` copies of one example: data is `batch_size + 1` ones, target
  // is `batch_size - 1` zeros. Once `counter` reaches the limit, an empty
  // optional is returned instead.
  struct D : datasets::StatefulDataset<D> {
    torch::optional<std::vector<Example<>>> get_batch(
        size_t batch_size) override {
      if (counter >= kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return torch::nullopt;
      }
      counter += batch_size;
      const Example<> prototype{torch::ones(batch_size + 1),
                                torch::zeros(batch_size - 1)};
      std::vector<Example<>> examples(/*count=*/batch_size, prototype);
      return examples;
    }

    // Note: the reported size differs from the exhaustion limit above.
    torch::optional<size_t> size() const override {
      return 100;
    }

    void reset() override {
      counter = 0;
    }

    int counter = 0;
  };

  auto dataset = D().map(transforms::Stack<Example<>>());

  const size_t kBatchSize = 5;

  // The underlying dataset yields a vector<Example>, while the mapped dataset
  // yields a single Example whose tensors are the stacked inputs.
  torch::optional<Example<>> stacked = dataset.get_batch(kBatchSize);
  ASSERT_TRUE(stacked.has_value());

  // Data: kBatchSize rows of (kBatchSize + 1) ones.
  ASSERT_EQ(stacked->data.size(0), kBatchSize);
  ASSERT_EQ(stacked->data.size(1), kBatchSize + 1);

  // Target: kBatchSize rows of (kBatchSize - 1) zeros.
  ASSERT_EQ(stacked->target.size(0), kBatchSize);
  ASSERT_EQ(stacked->target.size(1), kBatchSize - 1);

  ASSERT_TRUE(stacked->data[0].allclose(torch::ones(kBatchSize + 1)));
  ASSERT_TRUE(stacked->target[0].allclose(torch::zeros(kBatchSize - 1)));
}