#include #include #include #include #include namespace torch { namespace jit { namespace mobile { SequentialSampler::SequentialSampler(size_t size) : size_(size) {} void SequentialSampler::reset(std::optional new_size) { if (new_size.has_value()) { size_ = *new_size; } index_ = 0; } optional> SequentialSampler::next(size_t batch_size) { const auto remaining_indices = size_ - index_; if (remaining_indices == 0) { return nullopt; } std::vector index_batch(std::min(batch_size, remaining_indices)); for (auto& i : index_batch) { i = index_++; } return index_batch; } void SequentialSampler::save(serialize::OutputArchive& archive) const { TORCH_CHECK( false, "Serialization of SequentialSampler not supported on mobile."); } void SequentialSampler::load(serialize::InputArchive& archive) { TORCH_CHECK( false, "Serialization of SequentialSampler not supported on mobile."); } size_t SequentialSampler::index() const noexcept { return index_; } } // namespace mobile } // namespace jit } // namespace torch