mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67746 Test Plan: Visual inspection. Sandcastle. Reviewed By: zertosh Differential Revision: D31986646 fbshipit-source-id: 91885c20c3cead3853c49abb9fe0a94a67f33cc8
191 lines
5.0 KiB
C++
191 lines
5.0 KiB
C++
#include <fcntl.h>
|
|
#include <poll.h>
|
|
#include <sys/mman.h>
|
|
#include <unistd.h>
|
|
#include <algorithm>
|
|
#include <cerrno>
|
|
#include <memory>
|
|
#include <set>
|
|
#include <unordered_map>
|
|
#include <vector>
|
|
|
|
#include <c10/util/tempfile.h>
|
|
|
|
#include <libshm/err.h>
|
|
#include <libshm/socket.h>
|
|
|
|
// Timeout, in milliseconds (2s), passed to poll() while no client is
// connected; if that poll times out the manager shuts down.
constexpr int SHUTDOWN_TIMEOUT = 2000;
|
|
|
|
// Debug logging helper.
//
// When DEBUG_LOG is defined, DEBUG(fmt, ...) prints a printf-style message to
// stderr, wrapped in the COLOR / RESET escape sequences and terminated by a
// newline. Otherwise DEBUG(...) expands to a no-op expression and its
// arguments are not evaluated.
#ifdef DEBUG_LOG
#define COLOR "\033[31;1m"
#define RESET "\033[0m"
// Deliberately no trailing semicolon: DEBUG(...) must expand to a single
// expression in both build modes so that `if (c) DEBUG("x"); else ...`
// compiles the same way with and without DEBUG_LOG.
#define __DEBUG(msg, ...) fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__)
// Appends '\n' as the final argument, consumed by the "%c" added in __DEBUG.
#define DEBUG(...) __DEBUG(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
|
|
|
|
struct ClientSession {
|
|
ClientSession(ManagerSocket s) : socket(std::move(s)), pid(0) {}
|
|
|
|
ManagerSocket socket;
|
|
pid_t pid;
|
|
};
|
|
|
|
std::vector<struct pollfd> pollfds;
|
|
std::unordered_map<int, ClientSession> client_sessions;
|
|
// TODO: check if objects have been freed from time to time
|
|
std::set<std::string> used_objects;
|
|
|
|
// Adds `fd` to the set of descriptors polled by the main loop, asking to be
// notified when it becomes readable (POLLIN).
void register_fd(int fd) {
  // emplace_back() value-initializes the new entry, so every field
  // (including revents) starts out as zero.
  pollfds.emplace_back();
  auto& entry = pollfds.back();
  entry.fd = fd;
  entry.events = POLLIN;
}
|
|
|
|
// Stops polling `fd` and drops the client session stored under it, if any.
// Every pollfds entry whose descriptor equals `fd` is removed.
void unregister_fd(int fd) {
  const auto refers_to_fd = [fd](const struct pollfd& entry) {
    return entry.fd == fd;
  };
  const auto first_stale =
      std::remove_if(pollfds.begin(), pollfds.end(), refers_to_fd);
  pollfds.erase(first_stale, pollfds.end());

  client_sessions.erase(fd);
}
|
|
|
|
// Writes `message` followed by a newline straight to file descriptor 1
// (stdout) using write(). The results of write() are stored only so that
// they are not reported as ignored; failures and short writes are not handled.
void print_init_message(const char* message) {
  const size_t message_len = strlen(message);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  size_t ignored = write(1, message, message_len);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  ignored = write(1, "\n", 1);
  (void)ignored;
}
|
|
|
|
// Returns true if the shared-memory object `name` can be opened read-only
// with shm_open(), false otherwise. The probe descriptor is closed again
// before returning.
bool object_exists(const char* name) {
  const int probe = shm_open(name, O_RDONLY, 0);
  if (probe < 0) {
    return false;
  }
  close(probe);
  return true;
}
|
|
|
|
void free_used_object(const std::string& name) {
|
|
if (!object_exists(name.c_str())) {
|
|
DEBUG("object %s appears to have been freed", name.c_str());
|
|
used_objects.erase(name);
|
|
} else {
|
|
DEBUG("object %s still exists", name.c_str());
|
|
}
|
|
}
|
|
|
|
// NOLINTNEXTLINE(bugprone-exception-escape)
|
|
int main(int argc, char* argv[]) {
|
|
setsid(); // Daemonize the process
|
|
|
|
std::unique_ptr<ManagerServerSocket> srv_socket;
|
|
c10::optional<c10::TempDir> tempdir;
|
|
try {
|
|
tempdir = c10::try_make_tempdir(/*name_prefix=*/"torch-shm-dir-");
|
|
if (!tempdir.has_value()) {
|
|
throw std::runtime_error(
|
|
"could not generate a random directory for manager socket");
|
|
}
|
|
|
|
std::string tempfile = tempdir->name + "/manager.sock";
|
|
|
|
srv_socket = std::make_unique<ManagerServerSocket>(tempfile);
|
|
register_fd(srv_socket->socket_fd);
|
|
print_init_message(tempfile.c_str());
|
|
DEBUG("opened socket %s", tempfile.c_str());
|
|
} catch (const std::exception& e) {
|
|
std::string message("ERROR: ");
|
|
message += e.what();
|
|
print_init_message(message.c_str());
|
|
return 1;
|
|
} catch (...) {
|
|
print_init_message("ERROR: unhandled exception");
|
|
return 1;
|
|
}
|
|
|
|
int timeout = -1;
|
|
std::vector<int> to_add;
|
|
std::vector<int> to_remove;
|
|
for (;;) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
int nevents;
|
|
if (client_sessions.size() == 0)
|
|
timeout = SHUTDOWN_TIMEOUT;
|
|
SYSCHECK_ERR_RETURN_NEG1(
|
|
nevents = poll(pollfds.data(), pollfds.size(), timeout));
|
|
timeout = -1;
|
|
if (nevents == 0 && client_sessions.size() == 0)
|
|
break;
|
|
|
|
for (auto& pfd : pollfds) {
|
|
if (pfd.revents & (POLLERR | POLLHUP)) {
|
|
// some process died
|
|
DEBUG("detaching process");
|
|
auto& session = client_sessions.at(pfd.fd);
|
|
(void)session;
|
|
DEBUG("%d has died", session.pid);
|
|
to_remove.push_back(pfd.fd);
|
|
} else if (pfd.revents & POLLIN) {
|
|
if (pfd.fd == srv_socket->socket_fd) {
|
|
// someone is joining
|
|
DEBUG("registered new client");
|
|
auto client = srv_socket->accept();
|
|
int fd = client.socket_fd;
|
|
to_add.push_back(fd);
|
|
client_sessions.emplace(fd, std::move(client));
|
|
} else {
|
|
// someone wants to register a segment
|
|
DEBUG("got alloc info");
|
|
auto& session = client_sessions.at(pfd.fd);
|
|
AllocInfo info = session.socket.receive();
|
|
session.pid = info.pid;
|
|
DEBUG(
|
|
"got alloc info: %d %d %s",
|
|
(int)info.free,
|
|
info.pid,
|
|
info.filename);
|
|
if (info.free) {
|
|
free_used_object(info.filename);
|
|
} else {
|
|
used_objects.insert(info.filename);
|
|
DEBUG("registered object %s", info.filename);
|
|
session.socket.confirm();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for (int fd : to_add)
|
|
register_fd(fd);
|
|
to_add.clear();
|
|
|
|
for (int fd : to_remove)
|
|
unregister_fd(fd);
|
|
to_remove.clear();
|
|
}
|
|
|
|
for (auto& obj_name : used_objects) {
|
|
DEBUG("freeing %s", obj_name.c_str());
|
|
shm_unlink(obj_name.c_str());
|
|
}
|
|
|
|
// Clean up file descriptors
|
|
for (auto& pfd : pollfds) {
|
|
unregister_fd(pfd.fd);
|
|
}
|
|
// Clean up manager.sock
|
|
srv_socket->remove();
|
|
// Clean up directory automatically
|
|
|
|
DEBUG("manager done");
|
|
return 0;
|
|
}
|