mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 06:11:27 +08:00
Support stateful dataset (#15096)
Summary: Currently re-implements the dataloader for stateful datasets. Outstanding work: - Refactor DataLoader and DataLoader2 to have common base classes and only differ in specifi pieces of logic, - Figure out how to not duplicate the `MapDataset` logic for stateful vs. non-stateful Pull Request resolved: https://github.com/pytorch/pytorch/pull/15096 Differential Revision: D13522043 Pulled By: goldsborough fbshipit-source-id: 08e461ca51783047f11facc4d27dfa2e4f1e4c2a
This commit is contained in:
committed by
Facebook Github Bot
parent
8cd917812b
commit
ad6799537e
@ -84,8 +84,8 @@ TEST(DataTest, InfiniteStreamDataset) {
|
||||
|
||||
auto data_loader = torch::data::make_data_loader(
|
||||
std::move(dataset),
|
||||
kBatchSize,
|
||||
samplers::StreamSampler(/*epoch_size=*/39));
|
||||
samplers::StreamSampler(/*epoch_size=*/39),
|
||||
kBatchSize);
|
||||
|
||||
size_t batch_index = 0;
|
||||
for (auto& batch : *data_loader) {
|
||||
@ -128,10 +128,10 @@ TEST(DataTest, OrderedSequencerReOrdersValues) {
|
||||
size_t index = 0;
|
||||
auto getter = [&v, &index]() { return S{v.at(index++)}; };
|
||||
|
||||
// Let's say the sequence number matches for the first one, then it should
|
||||
// Let's say the sequence number matches for the batch one, then it should
|
||||
// return immediately.
|
||||
const auto first = sequencer.next(getter);
|
||||
ASSERT_EQ(first.value().sequence_number, 0);
|
||||
const auto batch = sequencer.next(getter);
|
||||
ASSERT_EQ(batch.value().sequence_number, 0);
|
||||
ASSERT_EQ(index, 1);
|
||||
|
||||
// Now it should call the getter until it gets the next value.
|
||||
@ -385,9 +385,9 @@ TEST(DataTest, StackTransformWorksForExample) {
|
||||
|
||||
auto d = D().map(transforms::Stack<Example<>>());
|
||||
|
||||
Example<> first = d.get_batch({0, 1});
|
||||
ASSERT_TRUE(first.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
|
||||
ASSERT_TRUE(first.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));
|
||||
Example<> batch = d.get_batch({0, 1});
|
||||
ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
|
||||
ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));
|
||||
|
||||
Example<> second = d.get_batch({2, 3});
|
||||
ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
|
||||
@ -398,8 +398,8 @@ TEST(DataTest, StackTransformWorksForTensorExample) {
|
||||
auto d = datasets::TensorDataset(torch::eye(4))
|
||||
.map(transforms::Stack<TensorExample>());
|
||||
|
||||
TensorExample first = d.get_batch({0, 1});
|
||||
ASSERT_TRUE(first.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
|
||||
TensorExample batch = d.get_batch({0, 1});
|
||||
ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
|
||||
|
||||
TensorExample second = d.get_batch({2, 3});
|
||||
ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
|
||||
@ -504,7 +504,7 @@ TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
|
||||
TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
|
||||
using torch::data::detail::Queue;
|
||||
|
||||
// First test: push first and the pop in thread.
|
||||
// First test: push batch and the pop in thread.
|
||||
{
|
||||
Queue<int> queue;
|
||||
queue.push(1);
|
||||
@ -513,7 +513,7 @@ TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
|
||||
ASSERT_EQ(future.get(), 1);
|
||||
}
|
||||
|
||||
// Second test: attempt to pop first (and block), then push.
|
||||
// Second test: attempt to pop batch (and block), then push.
|
||||
{
|
||||
Queue<int> queue;
|
||||
std::thread thread([&queue] {
|
||||
@ -544,7 +544,7 @@ TEST(DataTest, DataShuttleCanPushAndPopJob) {
|
||||
|
||||
TEST(DataTest, DataShuttleCanPushAndPopResult) {
|
||||
torch::data::detail::DataShuttle<int, int> shuttle;
|
||||
// pop_result() will only attempt to pop if there was a push_job() first.
|
||||
// pop_result() will only attempt to pop if there was a push_job() batch.
|
||||
shuttle.push_job(1);
|
||||
shuttle.push_job(2);
|
||||
|
||||
@ -672,9 +672,9 @@ struct TestIndexSampler : public samplers::Sampler<TestIndex> {
|
||||
};
|
||||
|
||||
TEST(DataTest, CanUseCustomTypeAsIndexType) {
|
||||
const size_t kBatchSize = 10;
|
||||
const int kBatchSize = 10;
|
||||
auto data_loader = torch::data::make_data_loader(
|
||||
TestIndexDataset(23), kBatchSize, TestIndexSampler(23));
|
||||
TestIndexDataset(23), TestIndexSampler(23), kBatchSize);
|
||||
|
||||
size_t i = 0;
|
||||
for (auto batch : *data_loader) {
|
||||
@ -948,7 +948,7 @@ TEST(DataLoaderTest, RespectsTimeout) {
|
||||
ASSERT_LT(duration.count(), 1);
|
||||
}
|
||||
|
||||
// https://stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
|
||||
// stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
|
||||
struct Barrier {
|
||||
explicit Barrier(size_t target) : counter_(target) {}
|
||||
void wait() {
|
||||
@ -973,12 +973,12 @@ struct Barrier {
|
||||
// thread (for outside consumption) is not deterministic. Imagine the sampler is
|
||||
// a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index
|
||||
// will be a single "job". Inside the dataloader, worker threads block until a
|
||||
// job is available. It is not deterministic which worker thread wakes up first
|
||||
// job is available. It is not deterministic which worker thread wakes up batch
|
||||
// to dequeue a particular batch. Further, some worker threads may take longer
|
||||
// than others to read the data for their index. As such, it could be that
|
||||
// worker thread 2 finishes before all other threads and returns its batch to
|
||||
// the main thread. In that case, the dataloader iterator would return the datum
|
||||
// at index 2 first, and afterwards the datum from whatever thread finishes
|
||||
// at index 2 batch, and afterwards the datum from whatever thread finishes
|
||||
// next. As such, the user may see data from indices 2, 0, 3, 1. On another run
|
||||
// of the same dataloader on the same data, threads may be scheduled differently
|
||||
// and return in order 0, 2, 3, 1. To force this ordering to deterministically
|
||||
@ -996,7 +996,7 @@ struct Barrier {
|
||||
// `SequentialSampler` in the range `0...kNumberOfWorkers-1`. Each worker thread
|
||||
// has a copy of the dataset, and thus `get_batch()` is called on the
|
||||
// thread-local copy in each worker. We want to simulate out-of-order completion
|
||||
// of these threads. For this, we first set a barrier in the `get_batch()`
|
||||
// of these threads. For this, we batch set a barrier in the `get_batch()`
|
||||
// method to make sure every worker has some index to fetch assigned. Further,
|
||||
// each worker thread has a unique ID in `0...kNumberOfWorkers-1`.
|
||||
// There is a hard-coded ordering, `kOrderInWhichWorkersReturnTheirBatch`, in
|
||||
@ -1057,12 +1057,11 @@ struct Dataset : datasets::BatchDataset<Dataset, size_t> {
|
||||
TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
|
||||
auto data_loader = torch::data::make_data_loader(
|
||||
ordering_test::Dataset{},
|
||||
torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers),
|
||||
DataLoaderOptions()
|
||||
.batch_size(1)
|
||||
.workers(ordering_test::kNumberOfWorkers)
|
||||
.enforce_ordering(true),
|
||||
torch::data::samplers::SequentialSampler(
|
||||
ordering_test::kNumberOfWorkers));
|
||||
.enforce_ordering(true));
|
||||
std::vector<size_t> output;
|
||||
for (size_t value : *data_loader) {
|
||||
output.push_back(value);
|
||||
@ -1104,8 +1103,8 @@ TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
|
||||
}
|
||||
};
|
||||
|
||||
auto data_loader =
|
||||
torch::data::make_data_loader(D{}, DataLoaderOptions().workers(2));
|
||||
auto data_loader = torch::data::make_data_loader(
|
||||
D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
|
||||
auto iterator = data_loader->begin();
|
||||
|
||||
try {
|
||||
@ -1119,3 +1118,159 @@ TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
|
||||
std::rethrow_exception(e.original_exception), std::invalid_argument);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
|
||||
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
||||
|
||||
struct D : datasets::StatefulDataset<D, int, size_t> {
|
||||
torch::optional<int> get_batch(size_t) override {
|
||||
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
||||
return counter++;
|
||||
}
|
||||
return torch::nullopt;
|
||||
}
|
||||
torch::optional<size_t> size() const override {
|
||||
return 100;
|
||||
}
|
||||
void reset() override {
|
||||
counter = 0;
|
||||
}
|
||||
int counter = 0;
|
||||
};
|
||||
|
||||
auto data_loader = torch::data::make_data_loader(D{});
|
||||
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
const auto number_of_iterations =
|
||||
std::distance(data_loader->begin(), data_loader->end());
|
||||
ASSERT_EQ(
|
||||
number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
|
||||
<< "epoch " << i;
|
||||
}
|
||||
|
||||
for (const int i : *data_loader) {
|
||||
ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
|
||||
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
||||
const int kNumberOfWorkers = 4;
|
||||
|
||||
struct D : datasets::StatefulDataset<D, int, size_t> {
|
||||
torch::optional<int> get_batch(size_t) override {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
||||
return counter++;
|
||||
}
|
||||
return torch::nullopt;
|
||||
}
|
||||
torch::optional<size_t> size() const override {
|
||||
return 100;
|
||||
}
|
||||
void reset() override {
|
||||
counter = 0;
|
||||
}
|
||||
int counter = 0;
|
||||
std::mutex mutex;
|
||||
};
|
||||
|
||||
auto data_loader = torch::data::make_data_loader(
|
||||
torch::data::datasets::make_shared_dataset<D>(),
|
||||
DataLoaderOptions().workers(kNumberOfWorkers));
|
||||
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
const auto number_of_iterations =
|
||||
std::distance(data_loader->begin(), data_loader->end());
|
||||
ASSERT_EQ(
|
||||
number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
|
||||
<< "epoch " << i;
|
||||
}
|
||||
|
||||
for (const int i : *data_loader) {
|
||||
ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DataLoaderTest, StatefulDatasetWithMap) {
|
||||
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
||||
|
||||
struct D : datasets::StatefulDataset<D, int, size_t> {
|
||||
torch::optional<int> get_batch(size_t) override {
|
||||
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
||||
return counter++;
|
||||
}
|
||||
return torch::nullopt;
|
||||
}
|
||||
torch::optional<size_t> size() const override {
|
||||
return 100;
|
||||
}
|
||||
void reset() override {
|
||||
counter = 0;
|
||||
}
|
||||
int counter = 0;
|
||||
};
|
||||
|
||||
auto data_loader = torch::data::make_data_loader(
|
||||
D().map(transforms::BatchLambda<int, std::string>(
|
||||
[](int x) { return std::to_string(x); }))
|
||||
.map(transforms::BatchLambda<std::string, torch::Tensor>(
|
||||
[](const std::string& x) {
|
||||
return torch::tensor(static_cast<int64_t>(std::stoi(x)));
|
||||
})),
|
||||
DataLoaderOptions{});
|
||||
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
const auto number_of_iterations =
|
||||
std::distance(data_loader->begin(), data_loader->end());
|
||||
ASSERT_EQ(
|
||||
number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
|
||||
<< "epoch " << i;
|
||||
}
|
||||
|
||||
for (const torch::Tensor& t : *data_loader) {
|
||||
ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DataLoaderTest, StatefulDatasetWithCollate) {
|
||||
const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
|
||||
|
||||
struct D : datasets::StatefulDataset<D> {
|
||||
torch::optional<std::vector<Example<>>> get_batch(
|
||||
size_t batch_size) override {
|
||||
if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
|
||||
counter += batch_size;
|
||||
std::vector<Example<>> batch(
|
||||
/*count=*/batch_size,
|
||||
Example<>{torch::ones(batch_size + 1),
|
||||
torch::zeros(batch_size - 1)});
|
||||
return batch;
|
||||
}
|
||||
return torch::nullopt;
|
||||
}
|
||||
torch::optional<size_t> size() const override {
|
||||
return 100;
|
||||
}
|
||||
void reset() override {
|
||||
counter = 0;
|
||||
}
|
||||
int counter = 0;
|
||||
};
|
||||
|
||||
auto d = D().map(transforms::Stack<Example<>>());
|
||||
|
||||
const size_t kBatchSize = 5;
|
||||
|
||||
// Notice that the `get_batch()` of the dataset returns a vector<Example>, but
|
||||
// the `Stack` collation stacks the tensors into one.
|
||||
torch::optional<Example<>> batch = d.get_batch(kBatchSize);
|
||||
ASSERT_TRUE(batch.has_value());
|
||||
ASSERT_EQ(batch->data.size(0), kBatchSize);
|
||||
ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
|
||||
ASSERT_EQ(batch->target.size(0), kBatchSize);
|
||||
ASSERT_EQ(batch->target.size(1), kBatchSize - 1);
|
||||
|
||||
ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
|
||||
ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
|
||||
}
|
||||
|
Reference in New Issue
Block a user