mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This series of changes converts C-style casts into their C++ alternatives. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750 Approved by: https://github.com/Skylion007
496 lines
13 KiB
C++
496 lines
13 KiB
C++
#include <c10/util/error.h>

#include <torch/csrc/distributed/c10d/FileStore.hpp>

#include <fcntl.h>
#include <sys/stat.h>
#include <cassert>
#include <cerrno>
#include <cstdint>

#ifdef _WIN32
#include <c10/util/FileSystem.h>
#include <c10/util/win32-headers.h>
#include <fileapi.h>
#include <io.h>
#else
#include <sys/file.h>
#include <unistd.h>
#endif

#include <chrono>
#include <cstdio>
#include <exception>
#include <limits>
#include <thread>
#include <utility>

#include <c10/util/Exception.h>
|
|
|
|
// Throws a DistStoreError carrying the text for the current `errno` when the
// syscall result `rv` is negative. The trailing arguments name the failing call
// at the call site for readability; they are NOT included in the message.
//
// The do { ... } while (0) wrapper makes the macro expand to exactly one
// statement, so it is safe inside an unbraced if/else.
#define SYSASSERT(rv, ...)                                           \
  do {                                                               \
    if ((rv) < 0) {                                                  \
      C10_THROW_ERROR(DistStoreError, c10::utils::str_error(errno)); \
    }                                                                \
  } while (0)
|
|
|
|
#ifdef _WIN32
// Windows has no flock(2); emulate its operation codes.
#define LOCK_EX 0x00000001
#define LOCK_SH 0x00000010
#define LOCK_UN 0x00000100

// Minimal flock(2) replacement built on LockFileEx / UnlockFileEx.
// It locks or unlocks the single byte at offset 0 of the file behind the CRT
// descriptor `fd`.
// Returns 0 on success and -1 on failure:
// - errno is set to EINVAL when the Win32 call fails or `op` is not one of
//   LOCK_EX / LOCK_SH / LOCK_UN;
// - errno is left as set by _get_osfhandle when `fd` has no OS handle.
#if defined(_WIN32) && defined(USE_ROCM)
static
#endif
int flock_(int fd, int op) {
  // Check the raw intptr_t before converting it: _get_osfhandle reports
  // failure with a negative value (INVALID_HANDLE_VALUE).
  const intptr_t osHandle = _get_osfhandle(fd);
  if (osHandle < 0) {
    return -1;
  }
  HANDLE hdl = reinterpret_cast<HANDLE>(osHandle);

  const DWORD low = 1;
  const DWORD high = 0;
  // Zero-initialised: lock region starts at offset 0, no event handle.
  OVERLAPPED offset{};

  switch (op) {
    case LOCK_EX:
      if (LockFileEx(hdl, LOCKFILE_EXCLUSIVE_LOCK, 0, low, high, &offset)) {
        return 0;
      }
      break;
    case LOCK_SH:
      // Flags 0 requests a shared lock.
      if (LockFileEx(hdl, 0, 0, low, high, &offset)) {
        return 0;
      }
      break;
    case LOCK_UN:
      if (UnlockFileEx(hdl, 0, low, high, &offset) != 0) {
        return 0;
      }
      break;
    default:
      break;
  }
  errno = EINVAL;
  return -1;
}
#endif
|
|
|
|
namespace c10d {
|
|
|
|
namespace {
|
|
|
|
// Invokes `fn`, retrying for as long as it fails with -1 / EINTR (the call was
// interrupted before completing). Any other result is returned unchanged,
// including a -1 failure with some other errno, which is left as `fn` set it.
template <typename F>
auto syscall(F fn) {
  auto rv = fn();
  while (rv == -1 && errno == EINTR) {
    rv = fn();
  }
  return rv;
}
|
|
|
|
// For a comprehensive overview of file locking methods,
|
|
// see: https://gavv.github.io/blog/file-locks/.
|
|
// We stick to flock(2) here because we don't care about
|
|
// locking byte ranges and don't want locks to be process-wide.
|
|
|
|
// RAII wrapper around flock(2)
|
|
class Lock {
|
|
public:
|
|
explicit Lock(int fd, int operation) : fd_(fd) {
|
|
flock(operation);
|
|
}
|
|
|
|
// NOLINTNEXTLINE(bugprone-exception-escape)
|
|
~Lock() {
|
|
unlock();
|
|
}
|
|
|
|
Lock(const Lock& that) = delete;
|
|
|
|
Lock& operator=(const Lock& other) = delete;
|
|
Lock& operator=(Lock&& other) noexcept {
|
|
if (this != &other) {
|
|
fd_ = other.fd_;
|
|
other.fd_ = -1;
|
|
}
|
|
return *this;
|
|
}
|
|
|
|
Lock(Lock&& other) noexcept {
|
|
*this = std::move(other);
|
|
}
|
|
|
|
void unlock() {
|
|
if (fd_ >= 0) {
|
|
flock(LOCK_UN);
|
|
fd_ = -1;
|
|
}
|
|
}
|
|
|
|
protected:
|
|
int fd_{-1};
|
|
|
|
void flock(int operation) {
|
|
#ifdef _WIN32
|
|
auto rv = syscall(std::bind(::flock_, fd_, operation));
|
|
#else
|
|
auto rv = syscall([this, operation] { return ::flock(fd_, operation); });
|
|
#endif
|
|
SYSASSERT(rv, "flock");
|
|
}
|
|
};
|
|
|
|
class File {
|
|
public:
|
|
explicit File(
|
|
const std::string& path,
|
|
int flags,
|
|
std::chrono::milliseconds timeout) {
|
|
const auto start = std::chrono::steady_clock::now();
|
|
while (true) {
|
|
#ifdef _WIN32
|
|
fd_ = syscall(std::bind(
|
|
::open, path.c_str(), flags | _O_BINARY, _S_IREAD | _S_IWRITE));
|
|
#else
|
|
fd_ = syscall([capture0 = path.c_str(), flags] {
|
|
return ::open(capture0, flags, 0644);
|
|
});
|
|
#endif
|
|
// Only retry when the file doesn't exist, since we are waiting for the
|
|
// file to be created in this case to address the following issue:
|
|
// https://github.com/pytorch/pytorch/issues/13750
|
|
if (fd_ >= 0 || errno != ENOENT) {
|
|
break;
|
|
}
|
|
#ifdef _WIN32
|
|
// if the parent folder doesn't exist it will never be able to create the
|
|
// file so we can skip the retry
|
|
if (!c10::filesystem::exists(c10::filesystem::path(path).parent_path())) {
|
|
break;
|
|
}
|
|
#endif
|
|
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
|
|
std::chrono::steady_clock::now() - start);
|
|
if (timeout != c10d::Store::kNoTimeout && elapsed > timeout) {
|
|
break;
|
|
}
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
}
|
|
SYSASSERT(fd_, "open(" + path + ")");
|
|
}
|
|
File(const File&) = delete;
|
|
File& operator=(const File&) = delete;
|
|
File(File&&) noexcept = delete;
|
|
File& operator=(File&&) noexcept = delete;
|
|
|
|
~File() {
|
|
::close(fd_);
|
|
}
|
|
|
|
Lock lockShared() {
|
|
return Lock(fd_, LOCK_SH);
|
|
}
|
|
|
|
Lock lockExclusive() {
|
|
return Lock(fd_, LOCK_EX);
|
|
}
|
|
|
|
off_t seek(off_t offset, int whence) {
|
|
auto rv =
|
|
syscall([this, offset, whence] { return lseek(fd_, offset, whence); });
|
|
SYSASSERT(rv, "lseek");
|
|
return rv;
|
|
}
|
|
|
|
off_t tell() {
|
|
auto rv = syscall([this] { return lseek(fd_, 0, SEEK_CUR); });
|
|
SYSASSERT(rv, "lseek");
|
|
return rv;
|
|
}
|
|
|
|
off_t size() {
|
|
auto pos = tell();
|
|
auto size = seek(0, SEEK_END);
|
|
seek(pos, SEEK_SET);
|
|
return size;
|
|
}
|
|
|
|
void write(const void* buf, size_t count) {
|
|
while (count > 0) {
|
|
auto rv =
|
|
syscall([this, buf, count] { return ::write(fd_, buf, count); });
|
|
SYSASSERT(rv, "write");
|
|
buf = (uint8_t*)buf + rv;
|
|
count -= rv;
|
|
}
|
|
}
|
|
|
|
void read(void* buf, size_t count) {
|
|
while (count > 0) {
|
|
auto rv = syscall([this, buf, count] { return ::read(fd_, buf, count); });
|
|
SYSASSERT(rv, "read");
|
|
buf = static_cast<uint8_t*>(buf) + rv;
|
|
count -= rv;
|
|
}
|
|
}
|
|
|
|
void write(const std::string& str) {
|
|
uint32_t len = str.size();
|
|
assert(str.size() <= std::numeric_limits<decltype(len)>::max());
|
|
write(&len, sizeof(len));
|
|
write(str.c_str(), len);
|
|
}
|
|
|
|
void write(const std::vector<uint8_t>& data) {
|
|
uint32_t len = data.size();
|
|
assert(data.size() <= std::numeric_limits<decltype(len)>::max());
|
|
write(&len, sizeof(len));
|
|
write(data.data(), len);
|
|
}
|
|
|
|
void read(std::string& str) {
|
|
uint32_t len = 0;
|
|
read(&len, sizeof(len));
|
|
std::vector<uint8_t> buf(len);
|
|
read(buf.data(), len);
|
|
str.assign(buf.begin(), buf.end());
|
|
}
|
|
|
|
void read(std::vector<uint8_t>& data) {
|
|
uint32_t len = 0;
|
|
read(&len, sizeof(len));
|
|
data.resize(len);
|
|
read(data.data(), len);
|
|
}
|
|
|
|
protected:
|
|
int fd_;
|
|
};
|
|
|
|
// Replays the entries appended to `file` since byte offset `pos` into `cache`.
// - An entry whose key starts with `deletePrefix` erases the key that follows
//   the prefix.
// - Any other entry sets (or overwrites) its key.
// Returns the offset just past the last entry consumed and leaves the file
// cursor at offset 0.
off_t refresh(
    File& file,
    off_t pos,
    std::unordered_map<std::string, std::vector<uint8_t>>& cache,
    const std::string& deletePrefix) {
  const auto size = file.size();
  if (size < pos) {
    // Every writer in this file seeks to the end before writing, so the file
    // only grows. If it is shorter than what we already consumed, it must
    // have been truncated or replaced (e.g. deleted and re-created).
    // The cached contents are then stale, and resuming at `pos` would later
    // start parsing in the middle of an entry.
    cache.clear();
    pos = 0;
  }
  if (size != pos) {
    std::string tmpKey;
    std::vector<uint8_t> tmpValue;
    file.seek(pos, SEEK_SET);
    while (size > pos) {
      file.read(tmpKey);
      file.read(tmpValue);
      if (tmpKey.compare(0, deletePrefix.size(), deletePrefix) == 0) {
        cache.erase(tmpKey.substr(deletePrefix.size()));
      } else {
        cache[tmpKey] = std::move(tmpValue);
      }
      pos = file.tell();
    }
  }
  file.seek(0, SEEK_SET);
  return pos;
}
|
|
|
|
} // namespace
|
|
|
|
// Creates a store backed by the file at `path`.
// `numWorkers` is the expected participant count; the destructor only deletes
// the file when it is non-negative and that many workers have finished.
// Each new instance increments the shared reference count kept in the file
// under refCountKey_.
FileStore::FileStore(std::string path, int numWorkers)
    : path_(std::move(path)),
      numWorkers_(numWorkers),
      cleanupKey_("cleanup/"),
      refCountKey_("refcount/"),
      regularPrefix_("/"),
      deletePrefix_("-") {
  // Register this instance; the resulting count is not needed here.
  [[maybe_unused]] const auto initialRefCount = addHelper(refCountKey_, 1);
}
|
|
|
|
// Returns a new FileStore backed by the same file and worker count.
// Like any instance, the clone registers itself in the shared reference
// count. It also inherits this store's timeout.
c10::intrusive_ptr<Store> FileStore::clone() {
  auto store = c10::make_intrusive<FileStore>(path_, numWorkers_);
  // The constructor takes no timeout. Without this, the clone would silently
  // fall back to the default store timeout.
  store->setTimeout(timeout_);
  return store;
}
|
|
|
|
// NOLINTNEXTLINE(bugprone-exception-escape)
|
|
FileStore::~FileStore() {
|
|
// If the file does not exist - exit.
|
|
// This can happen when FileStore is invoked from python language which has
|
|
// GC. If python code has directory cleanup procedure, the race condition may
|
|
// occur between that code and this destructor. As a result, we check for
|
|
// file existence before cleanup
|
|
#ifdef _WIN32
|
|
int res = syscall(std::bind(::_access, path_.c_str(), 0));
|
|
#else
|
|
int res =
|
|
syscall([filepath = path_.c_str()] { return ::access(filepath, F_OK); });
|
|
#endif
|
|
if (res == -1) {
|
|
return;
|
|
}
|
|
|
|
// cleanup key will be different from all rest keys since all rest keys will
|
|
// have a regular prefix.
|
|
auto numFinishedWorker = addHelper(cleanupKey_, 1);
|
|
auto refCount = addHelper(refCountKey_, -1);
|
|
// The last worker cleans up the file. If numWorkers was not initialized to
|
|
// a specific positive value (i.e. meaning that there was not a fixed number
|
|
// of workers), we don't attempt to clean.
|
|
// Clean up the file if number of references is 0.
|
|
if (refCount == 0 && numWorkers_ >= 0 && numFinishedWorker >= numWorkers_) {
|
|
// Best effort removal without checking the return
|
|
::remove(path_.c_str());
|
|
}
|
|
}
|
|
|
|
void FileStore::set(const std::string& key, const std::vector<uint8_t>& value) {
|
|
std::string regKey = regularPrefix_ + key;
|
|
std::unique_lock<std::mutex> l(activeFileOpLock_);
|
|
File file(path_, O_RDWR | O_CREAT, timeout_);
|
|
auto lock = file.lockExclusive();
|
|
file.seek(0, SEEK_END);
|
|
file.write(regKey);
|
|
file.write(value);
|
|
}
|
|
|
|
std::vector<uint8_t> FileStore::compareSet(
|
|
const std::string& key,
|
|
const std::vector<uint8_t>& expectedValue,
|
|
const std::vector<uint8_t>& desiredValue) {
|
|
std::string regKey = regularPrefix_ + key;
|
|
std::unique_lock<std::mutex> l(activeFileOpLock_);
|
|
File file(path_, O_RDWR | O_CREAT, timeout_);
|
|
auto lock = file.lockExclusive();
|
|
// Always refresh since even though the key exists in the cache,
|
|
// it might be outdated
|
|
pos_ = refresh(file, pos_, cache_, deletePrefix_);
|
|
if ((cache_.count(regKey) == 0 && expectedValue.empty()) ||
|
|
(cache_.count(regKey) != 0 && cache_[regKey] == expectedValue)) {
|
|
// if the key does not exist and currentValue arg is empty or
|
|
// the key does exist and current value is what is expected, then set it
|
|
file.seek(0, SEEK_END);
|
|
file.write(regKey);
|
|
file.write(desiredValue);
|
|
return desiredValue;
|
|
} else if (cache_.count(regKey) == 0) {
|
|
// if the key does not exist
|
|
return expectedValue;
|
|
}
|
|
// key exists but current value is not expected
|
|
return cache_[regKey];
|
|
}
|
|
|
|
std::vector<uint8_t> FileStore::get(const std::string& key) {
|
|
std::string regKey = regularPrefix_ + key;
|
|
const auto start = std::chrono::steady_clock::now();
|
|
while (true) {
|
|
std::unique_lock<std::mutex> l(activeFileOpLock_);
|
|
File file(path_, O_RDONLY, timeout_);
|
|
auto lock = file.lockShared();
|
|
auto size = file.size();
|
|
if (cache_.count(regKey) == 0 && size == pos_) {
|
|
// No new entries; release the shared lock and sleep for a bit
|
|
lock.unlock();
|
|
l.unlock();
|
|
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
|
|
std::chrono::steady_clock::now() - start);
|
|
if (timeout_ != kNoTimeout && elapsed > timeout_) {
|
|
auto err = c10::str(
|
|
"Timeout waiting for key: ",
|
|
key,
|
|
" after ",
|
|
timeout_.count(),
|
|
" ms");
|
|
TORCH_CHECK(false, err);
|
|
}
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
continue;
|
|
}
|
|
// Always refresh since even though the key exists in the cache,
|
|
// it might be outdated
|
|
pos_ = refresh(file, pos_, cache_, deletePrefix_);
|
|
if (cache_.count(regKey) != 0) {
|
|
return cache_[regKey];
|
|
}
|
|
}
|
|
}
|
|
|
|
int64_t FileStore::addHelper(const std::string& key, int64_t i) {
|
|
std::unique_lock<std::mutex> l(activeFileOpLock_);
|
|
File file(path_, O_RDWR | O_CREAT, timeout_);
|
|
auto lock = file.lockExclusive();
|
|
pos_ = refresh(file, pos_, cache_, deletePrefix_);
|
|
|
|
const auto& value = cache_[key];
|
|
int64_t ti = i;
|
|
if (!value.empty()) {
|
|
auto buf = reinterpret_cast<const char*>(value.data());
|
|
auto len = value.size();
|
|
ti += std::stoll(std::string(buf, len));
|
|
}
|
|
// Always seek to the end to write
|
|
file.seek(0, SEEK_END);
|
|
// File cursor is at the end of the file now, and we have an
|
|
// exclusive lock, so we can write the new value.
|
|
file.write(key);
|
|
file.write(std::to_string(ti));
|
|
return ti;
|
|
}
|
|
|
|
int64_t FileStore::add(const std::string& key, int64_t value) {
|
|
std::string regKey = regularPrefix_ + key;
|
|
return addHelper(regKey, value);
|
|
}
|
|
|
|
int64_t FileStore::getNumKeys() {
|
|
std::unique_lock<std::mutex> l(activeFileOpLock_);
|
|
File file(path_, O_RDONLY, timeout_);
|
|
auto lock = file.lockShared();
|
|
pos_ = refresh(file, pos_, cache_, deletePrefix_);
|
|
return static_cast<int64_t>(cache_.size());
|
|
}
|
|
|
|
bool FileStore::deleteKey(const std::string& key) {
|
|
std::string deleteKey = deletePrefix_ + regularPrefix_ + key;
|
|
std::unique_lock<std::mutex> l(activeFileOpLock_);
|
|
File file(path_, O_RDWR, timeout_);
|
|
auto lock = file.lockExclusive();
|
|
file.seek(0, SEEK_END);
|
|
file.write(deleteKey);
|
|
file.write(std::vector<uint8_t>{});
|
|
return true;
|
|
}
|
|
|
|
bool FileStore::check(const std::vector<std::string>& keys) {
|
|
std::unique_lock<std::mutex> l(activeFileOpLock_);
|
|
File file(path_, O_RDONLY, timeout_);
|
|
auto lock = file.lockShared();
|
|
pos_ = refresh(file, pos_, cache_, deletePrefix_);
|
|
|
|
for (const auto& key : keys) {
|
|
std::string regKey = regularPrefix_ + key;
|
|
if (cache_.count(regKey) == 0) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
void FileStore::wait(const std::vector<std::string>& keys) {
|
|
wait(keys, timeout_);
|
|
}
|
|
|
|
void FileStore::wait(
|
|
const std::vector<std::string>& keys,
|
|
const std::chrono::milliseconds& timeout) {
|
|
// Not using inotify because it doesn't work on many
|
|
// shared filesystems (such as NFS).
|
|
const auto start = std::chrono::steady_clock::now();
|
|
while (!check(keys)) {
|
|
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
|
|
std::chrono::steady_clock::now() - start);
|
|
if (timeout != kNoTimeout && elapsed > timeout) {
|
|
TORCH_CHECK(false, "Wait timeout");
|
|
}
|
|
|
|
/* sleep override */
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
}
|
|
}
|
|
|
|
} // namespace c10d
|