mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66746 Modified loops in files under fbsource/fbcode/caffe2/ from the format `for(TYPE var=x0;var<x_max;x++)` to the format `for(const auto var: irange(xmax))` This was achieved by running r-barnes's loop upgrader script (D28874212) with some modification to exclude all files under /torch/jit and a number of reversions or unused variable suppression warnings added by hand. Test Plan: Sandcastle Reviewed By: malfet Differential Revision: D31705361 fbshipit-source-id: 33fd22eb03086d114e2c98e56703e8ec84460268
338 lines
9.3 KiB
C++
338 lines
9.3 KiB
C++
#ifndef CAFFE2_CORE_DB_H_
|
|
#define CAFFE2_CORE_DB_H_
|
|
|
|
#include <mutex>
|
|
|
|
#include <c10/util/Registry.h>
|
|
#include <c10/util/irange.h>
|
|
#include <c10/util/string_view.h>
|
|
#include "caffe2/core/blob_serialization.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
|
|
namespace caffe2 {
|
|
namespace db {
|
|
|
|
/**
 * Access mode requested when a database is opened: read from it, write to
 * it, or create a brand-new one.
 */
enum Mode {
  READ, // read from the database
  WRITE, // write into the database
  NEW, // create a new database
};
|
|
|
|
/**
|
|
* An abstract class for the cursor of the database while reading.
|
|
*/
|
|
class TORCH_API Cursor {
|
|
public:
|
|
Cursor() {}
|
|
virtual ~Cursor() {}
|
|
/**
|
|
* Seek to a specific key (or if the key does not exist, seek to the
|
|
* immediate next). This is optional for dbs, and in default, SupportsSeek()
|
|
* returns false meaning that the db cursor does not support it.
|
|
*/
|
|
virtual void Seek(const string& key) = 0;
|
|
virtual bool SupportsSeek() {
|
|
return false;
|
|
}
|
|
/**
|
|
* Seek to the first key in the database.
|
|
*/
|
|
virtual void SeekToFirst() = 0;
|
|
/**
|
|
* Go to the next location in the database.
|
|
*/
|
|
virtual void Next() = 0;
|
|
/**
|
|
* Returns the current key.
|
|
*/
|
|
virtual string key() = 0;
|
|
/**
|
|
* Returns the current value.
|
|
*/
|
|
virtual string value() = 0;
|
|
/**
|
|
* Returns whether the current location is valid - for example, if we have
|
|
* reached the end of the database, return false.
|
|
*/
|
|
virtual bool Valid() = 0;
|
|
|
|
C10_DISABLE_COPY_AND_ASSIGN(Cursor);
|
|
};
|
|
|
|
/**
|
|
* An abstract class for the current database transaction while writing.
|
|
*/
|
|
class TORCH_API Transaction {
|
|
public:
|
|
Transaction() {}
|
|
virtual ~Transaction() {}
|
|
/**
|
|
* Puts the key value pair to the database.
|
|
*/
|
|
virtual void Put(const std::string& key, std::string&& value) = 0;
|
|
/**
|
|
* Commits the current writes.
|
|
*/
|
|
virtual void Commit() = 0;
|
|
|
|
C10_DISABLE_COPY_AND_ASSIGN(Transaction);
|
|
};
|
|
|
|
/**
|
|
* An abstract class for accessing a database of key-value pairs.
|
|
*/
|
|
class TORCH_API DB {
|
|
public:
|
|
DB(const string& /*source*/, Mode mode) : mode_(mode) {}
|
|
virtual ~DB() {}
|
|
/**
|
|
* Closes the database.
|
|
*/
|
|
virtual void Close() = 0;
|
|
/**
|
|
* Returns a cursor to read the database. The caller takes the ownership of
|
|
* the pointer.
|
|
*/
|
|
virtual std::unique_ptr<Cursor> NewCursor() = 0;
|
|
/**
|
|
* Returns a transaction to write data to the database. The caller takes the
|
|
* ownership of the pointer.
|
|
*/
|
|
virtual std::unique_ptr<Transaction> NewTransaction() = 0;
|
|
|
|
/**
|
|
* Set DB options.
|
|
*
|
|
* These options should apply for the lifetime of the DB, or until a
|
|
* subsequent SetOptions() call overrides them.
|
|
*
|
|
* This is used by the Save operator to allow the client to pass in
|
|
* DB-specific options to control the behavior. This is an opaque string,
|
|
* where the format is specific to the DB type. DB types may pass in a
|
|
* serialized protobuf message here if desired.
|
|
*/
|
|
virtual void SetOptions(c10::string_view /* options */) {}
|
|
|
|
protected:
|
|
Mode mode_;
|
|
|
|
C10_DISABLE_COPY_AND_ASSIGN(DB);
|
|
};
|
|
|
|
// Database classes are registered by their names so we can do optional
|
|
// dependencies.
|
|
C10_DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode);
|
|
#define REGISTER_CAFFE2_DB(name, ...) \
|
|
C10_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__)
|
|
|
|
/**
|
|
* Returns a database object of the given database type, source and mode. The
|
|
* caller takes the ownership of the pointer. If the database type is not
|
|
* supported, a nullptr is returned. The caller is responsible for examining the
|
|
* validity of the pointer.
|
|
*/
|
|
inline unique_ptr<DB>
|
|
CreateDB(const string& db_type, const string& source, Mode mode) {
|
|
auto result = Caffe2DBRegistry()->Create(db_type, source, mode);
|
|
VLOG(1) << ((!result) ? "not found db " : "found db ") << db_type;
|
|
return result;
|
|
}
|
|
|
|
/**
|
|
* Returns whether or not a database exists given the database type and path.
|
|
*/
|
|
inline bool DBExists(const string& db_type, const string& full_db_name) {
|
|
// Warning! We assume that creating a DB throws an exception if the DB
|
|
// does not exist. If the DB constructor does not follow this design
|
|
// pattern,
|
|
// the returned output (the existence tensor) can be wrong.
|
|
try {
|
|
std::unique_ptr<DB> db(
|
|
caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ));
|
|
return true;
|
|
} catch (...) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
/**
|
|
* A reader wrapper for DB that also allows us to serialize it.
|
|
*/
|
|
class TORCH_API DBReader {
|
|
public:
|
|
friend class DBReaderSerializer;
|
|
DBReader() {}
|
|
|
|
DBReader(
|
|
const string& db_type,
|
|
const string& source,
|
|
const int32_t num_shards = 1,
|
|
const int32_t shard_id = 0) {
|
|
Open(db_type, source, num_shards, shard_id);
|
|
}
|
|
|
|
explicit DBReader(const DBReaderProto& proto) {
|
|
Open(proto.db_type(), proto.source());
|
|
if (proto.has_key()) {
|
|
CAFFE_ENFORCE(
|
|
cursor_->SupportsSeek(),
|
|
"Encountering a proto that needs seeking but the db type "
|
|
"does not support it.");
|
|
cursor_->Seek(proto.key());
|
|
}
|
|
num_shards_ = 1;
|
|
shard_id_ = 0;
|
|
}
|
|
|
|
explicit DBReader(std::unique_ptr<DB> db)
|
|
: db_type_("<memory-type>"),
|
|
source_("<memory-source>"),
|
|
db_(std::move(db)) {
|
|
CAFFE_ENFORCE(db_.get(), "Passed null db");
|
|
cursor_ = db_->NewCursor();
|
|
}
|
|
|
|
void Open(
|
|
const string& db_type,
|
|
const string& source,
|
|
const int32_t num_shards = 1,
|
|
const int32_t shard_id = 0) {
|
|
// Note(jiayq): resetting is needed when we re-open e.g. leveldb where no
|
|
// concurrent access is allowed.
|
|
cursor_.reset();
|
|
db_.reset();
|
|
db_type_ = db_type;
|
|
source_ = source;
|
|
db_ = CreateDB(db_type_, source_, READ);
|
|
CAFFE_ENFORCE(
|
|
db_,
|
|
"Cannot find db implementation of type ",
|
|
db_type,
|
|
" (while trying to open ",
|
|
source_,
|
|
")");
|
|
InitializeCursor(num_shards, shard_id);
|
|
}
|
|
|
|
void Open(
|
|
unique_ptr<DB>&& db,
|
|
const int32_t num_shards = 1,
|
|
const int32_t shard_id = 0) {
|
|
cursor_.reset();
|
|
db_.reset();
|
|
db_ = std::move(db);
|
|
CAFFE_ENFORCE(db_.get(), "Passed null db");
|
|
InitializeCursor(num_shards, shard_id);
|
|
}
|
|
|
|
public:
|
|
/**
|
|
* Read a set of key and value from the db and move to next. Thread safe.
|
|
*
|
|
* The string objects key and value must be created by the caller and
|
|
* explicitly passed in to this function. This saves one additional object
|
|
* copy.
|
|
*
|
|
* If the cursor reaches its end, the reader will go back to the head of
|
|
* the db. This function can be used to enable multiple input ops to read
|
|
* the same db.
|
|
*
|
|
* Note(jiayq): we loosen the definition of a const function here a little
|
|
* bit: the state of the cursor is actually changed. However, this allows
|
|
* us to pass in a DBReader to an Operator without the need of a duplicated
|
|
* output blob.
|
|
*/
|
|
void Read(string* key, string* value) const {
|
|
CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
|
|
std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
|
|
*key = cursor_->key();
|
|
*value = cursor_->value();
|
|
|
|
// In sharded mode, each read skips num_shards_ records
|
|
for (const auto s : c10::irange(num_shards_)) {
|
|
(void)s; // Suppress unused variable
|
|
cursor_->Next();
|
|
if (!cursor_->Valid()) {
|
|
MoveToBeginning();
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief Seeks to the first key. Thread safe.
|
|
*/
|
|
void SeekToFirst() const {
|
|
CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
|
|
std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
|
|
MoveToBeginning();
|
|
}
|
|
|
|
/**
|
|
* Returns the underlying cursor of the db reader.
|
|
*
|
|
* Note that if you directly use the cursor, the read will not be thread
|
|
* safe, because there is no mechanism to stop multiple threads from
|
|
* accessing the same cursor. You should consider using Read() explicitly.
|
|
*/
|
|
inline Cursor* cursor() const {
|
|
VLOG(1) << "Usually for a DBReader you should use Read() to be "
|
|
"thread safe. Consider refactoring your code.";
|
|
return cursor_.get();
|
|
}
|
|
|
|
private:
|
|
void InitializeCursor(const int32_t num_shards, const int32_t shard_id) {
|
|
CAFFE_ENFORCE(num_shards >= 1);
|
|
CAFFE_ENFORCE(shard_id >= 0);
|
|
CAFFE_ENFORCE(shard_id < num_shards);
|
|
num_shards_ = num_shards;
|
|
shard_id_ = shard_id;
|
|
cursor_ = db_->NewCursor();
|
|
SeekToFirst();
|
|
}
|
|
|
|
void MoveToBeginning() const {
|
|
cursor_->SeekToFirst();
|
|
for (const auto s : c10::irange(shard_id_)) {
|
|
(void)s; // Suppress unused variable
|
|
cursor_->Next();
|
|
CAFFE_ENFORCE(
|
|
cursor_->Valid(), "Db has fewer rows than shard id: ", s, shard_id_);
|
|
}
|
|
}
|
|
|
|
string db_type_;
|
|
string source_;
|
|
unique_ptr<DB> db_;
|
|
unique_ptr<Cursor> cursor_;
|
|
mutable std::mutex reader_mutex_;
|
|
uint32_t num_shards_{};
|
|
uint32_t shard_id_{};
|
|
|
|
C10_DISABLE_COPY_AND_ASSIGN(DBReader);
|
|
};
|
|
|
|
/**
 * BlobSerializerBase implementation for blobs that hold a DBReader.
 */
class TORCH_API DBReaderSerializer : public BlobSerializerBase {
 public:
  /**
   * Serializes the DBReader stored at `pointer` as the blob `name`.
   * The blob must actually contain a DBReader; any other content is a
   * fatal error.
   */
  void Serialize(
      const void* pointer,
      TypeMeta typeMeta,
      const string& name,
      BlobSerializerBase::SerializationAcceptor acceptor) override;
};
|
|
|
|
/**
 * Counterpart of DBReaderSerializer: reads a BlobProto and fills `blob`.
 * NOTE(review): presumably rebuilds the reader via
 * DBReader(const DBReaderProto&); confirm against the definition.
 */
class TORCH_API DBReaderDeserializer : public BlobDeserializerBase {
 public:
  void Deserialize(const BlobProto& proto, Blob* blob) override;
};
|
|
|
|
} // namespace db
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_CORE_DB_H_
|