mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This reverts commit 7646825c3eb687030c4f873b01312be0eed80174. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127805 Approved by: https://github.com/PaliC
This commit is contained in:
committed by
PyTorch MergeBot
parent
e76b28c765
commit
597922ba21
@ -769,6 +769,7 @@ cc_library(
|
||||
":caffe2",
|
||||
":torch_headers",
|
||||
"@kineto",
|
||||
"@cpp-httplib",
|
||||
] + if_cuda([
|
||||
"@cuda//:nvToolsExt",
|
||||
"@cutlass",
|
||||
|
@ -168,6 +168,12 @@ new_local_repository(
|
||||
path = "third_party/opentelemetry-cpp",
|
||||
)
|
||||
|
||||
new_local_repository(
|
||||
name = "cpp-httplib",
|
||||
build_file = "//third_party:cpp-httplib.BUILD",
|
||||
path = "third_party/cpp-httplib",
|
||||
)
|
||||
|
||||
new_local_repository(
|
||||
name = "tensorpipe",
|
||||
build_file = "//third_party:tensorpipe.BUILD",
|
||||
|
@ -515,6 +515,8 @@ libtorch_distributed_base_sources = [
|
||||
"torch/csrc/distributed/c10d/sequence_num.cpp",
|
||||
"torch/csrc/distributed/c10d/socket.cpp",
|
||||
"torch/csrc/distributed/c10d/Work.cpp",
|
||||
"torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
|
||||
"torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
|
||||
]
|
||||
|
||||
# These files are only supported on Linux (and others) but not on Windows.
|
||||
|
@ -1179,6 +1179,9 @@ if(USE_KINETO)
|
||||
${TORCH_ROOT}/third_party/kineto/libkineto/src)
|
||||
endif()
|
||||
|
||||
target_include_directories(torch_cpu PRIVATE
|
||||
${TORCH_ROOT}/third_party/cpp-httplib)
|
||||
|
||||
install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
|
||||
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
|
||||
FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp")
|
||||
|
@ -1681,3 +1681,7 @@ endif()
|
||||
|
||||
# Include google/FlatBuffers
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)
|
||||
|
||||
# Include cpp-httplib
|
||||
add_library(httplib INTERFACE IMPORTED)
|
||||
target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/cpp-httplib)
|
||||
|
@ -29,6 +29,7 @@ Documentation
|
||||
elastic/metrics
|
||||
elastic/events
|
||||
elastic/subprocess_handler
|
||||
elastic/control_plane
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
10
docs/source/elastic/control_plane.rst
Normal file
10
docs/source/elastic/control_plane.rst
Normal file
@ -0,0 +1,10 @@
|
||||
Control Plane
|
||||
=============
|
||||
|
||||
.. automodule:: torch.distributed.elastic.control_plane
|
||||
.. currentmodule:: torch.distributed.elastic.control_plane
|
||||
|
||||
This module contains optional helpers that add extra debug and control handlers
|
||||
into your application.
|
||||
|
||||
.. autofunction:: torch.distributed.elastic.control_plane.worker_main
|
86
test/distributed/elastic/test_control_plane.py
Normal file
86
test/distributed/elastic/test_control_plane.py
Normal file
@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
from urllib3.connection import HTTPConnection
|
||||
from urllib3.connectionpool import HTTPConnectionPool
|
||||
|
||||
from torch.distributed.elastic.control_plane import (
|
||||
TORCH_WORKER_SERVER_SOCKET,
|
||||
worker_main,
|
||||
)
|
||||
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
|
||||
|
||||
|
||||
class UnixHTTPConnection(HTTPConnection):
|
||||
def __init__(self, socket_path: str) -> None:
|
||||
super().__init__("localhost")
|
||||
|
||||
self.socket_path = socket_path
|
||||
|
||||
def connect(self) -> None:
|
||||
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self.sock.connect(self.socket_path)
|
||||
|
||||
|
||||
class UnixHTTPConnectionPool(HTTPConnectionPool):
|
||||
def __init__(self, socket_path: str) -> None:
|
||||
super().__init__("localhost")
|
||||
|
||||
self.socket_path = socket_path
|
||||
|
||||
def _new_conn(self):
|
||||
return UnixHTTPConnection(self.socket_path)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def local_worker_server() -> None:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "socket.sock")
|
||||
os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path
|
||||
|
||||
with worker_main():
|
||||
pool = UnixHTTPConnectionPool(socket_path)
|
||||
yield pool
|
||||
|
||||
|
||||
class WorkerServerTest(TestCase):
|
||||
def test_worker_server(self) -> None:
|
||||
with local_worker_server() as pool:
|
||||
resp = pool.request("GET", "/")
|
||||
self.assertEqual(resp.status, 200)
|
||||
self.assertEqual(
|
||||
resp.data,
|
||||
b"""<h1>torch.distributed.WorkerServer</h1>
|
||||
<a href="/handler/">Handler names</a>
|
||||
""",
|
||||
)
|
||||
|
||||
resp = pool.request("POST", "/handler/ping")
|
||||
self.assertEqual(resp.status, 200)
|
||||
self.assertEqual(resp.data, b"pong")
|
||||
|
||||
resp = pool.request("GET", "/handler/")
|
||||
self.assertEqual(resp.status, 200)
|
||||
self.assertIn("ping", json.loads(resp.data))
|
||||
|
||||
resp = pool.request("POST", "/handler/nonexistant")
|
||||
self.assertEqual(resp.status, 404)
|
||||
self.assertIn(b"Handler nonexistant not found:", resp.data)
|
||||
|
||||
@requires_cuda
|
||||
def test_dump_nccl_trace_pickle(self) -> None:
|
||||
with local_worker_server() as pool:
|
||||
resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
|
||||
self.assertEqual(resp.status, 200)
|
||||
out = pickle.loads(resp.data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
10
third_party/cpp-httplib.BUILD
vendored
Normal file
10
third_party/cpp-httplib.BUILD
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
load("@rules_cc//cc:defs.bzl", "cc_library")
|
||||
|
||||
cc_library(
|
||||
name = "cpp-httplib",
|
||||
hdrs = ["httplib.h"],
|
||||
includes = [
|
||||
"/",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
@ -68,6 +68,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES
|
||||
${TORCH_ROOT}/third_party/onnx
|
||||
${TORCH_ROOT}/third_party/flatbuffers/include
|
||||
${TORCH_ROOT}/third_party/kineto/libkineto/include
|
||||
${TORCH_ROOT}/third_party/cpp-httplib
|
||||
|
||||
${TORCH_SRC_DIR}/csrc
|
||||
${TORCH_SRC_DIR}/csrc/api/include
|
||||
@ -80,6 +81,7 @@ set(TORCH_PYTHON_LINK_LIBRARIES
|
||||
Python::Module
|
||||
pybind::pybind11
|
||||
opentelemetry::api
|
||||
httplib
|
||||
shm
|
||||
fmt::fmt-header-only
|
||||
ATEN_CPU_FILES_GEN_LIB)
|
||||
|
@ -94,6 +94,10 @@ class Logger:
|
||||
def _set_uneven_input_join(self) -> None: ...
|
||||
def _set_static_graph(self) -> None: ...
|
||||
|
||||
class _WorkerServer:
|
||||
def __init__(self, socket_path: str) -> None: ...
|
||||
def shutdown(self) -> None: ...
|
||||
|
||||
def get_debug_level(): ...
|
||||
def set_debug_level(): ...
|
||||
def set_debug_level_from_env(): ...
|
||||
|
@ -28,6 +28,7 @@
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
|
||||
#include <torch/csrc/distributed/c10d/TraceUtils.h>
|
||||
#include <torch/csrc/distributed/c10d/Utils.hpp>
|
||||
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
|
||||
#include <torch/csrc/distributed/c10d/logger.hpp>
|
||||
#include <torch/torch.h>
|
||||
|
||||
@ -369,6 +370,13 @@ std::string dump_nccl_trace() {
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO(c-p-i-o): add a JSON endpoint.
|
||||
control_plane::RegisterHandler dumpHandler{
|
||||
"dump_nccl_trace_pickle",
|
||||
[](const control_plane::Request&, control_plane::Response& res) {
|
||||
res.setContent(dump_nccl_trace(), "application/octet-stream");
|
||||
}};
|
||||
|
||||
std::optional<std::function<void(std::function<void(const std::string&)>)>>&
|
||||
get_cpp_trace_dumper() {
|
||||
static std::optional<
|
||||
|
75
torch/csrc/distributed/c10d/control_plane/Handlers.cpp
Normal file
75
torch/csrc/distributed/c10d/control_plane/Handlers.cpp
Normal file
@ -0,0 +1,75 @@
|
||||
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace c10d {
|
||||
namespace control_plane {
|
||||
|
||||
namespace {
|
||||
|
||||
class HandlerRegistry {
|
||||
public:
|
||||
void registerHandler(const std::string& name, HandlerFunc f) {
|
||||
std::unique_lock<std::shared_mutex> lock(handlersMutex_);
|
||||
|
||||
if (handlers_.find(name) != handlers_.end()) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("Handler {} already registered", name));
|
||||
}
|
||||
|
||||
handlers_[name] = f;
|
||||
}
|
||||
|
||||
HandlerFunc getHandler(const std::string& name) {
|
||||
std::shared_lock<std::shared_mutex> lock(handlersMutex_);
|
||||
|
||||
auto it = handlers_.find(name);
|
||||
if (it == handlers_.end()) {
|
||||
throw std::runtime_error(fmt::format("Failed to find handler {}", name));
|
||||
}
|
||||
return handlers_[name];
|
||||
}
|
||||
|
||||
std::vector<std::string> getHandlerNames() {
|
||||
std::shared_lock<std::shared_mutex> lock(handlersMutex_);
|
||||
|
||||
std::vector<std::string> names;
|
||||
for (const auto& [name, _] : handlers_) {
|
||||
names.push_back(name);
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_mutex handlersMutex_{};
|
||||
std::unordered_map<std::string, HandlerFunc> handlers_{};
|
||||
};
|
||||
|
||||
HandlerRegistry& getHandlerRegistry() {
|
||||
static HandlerRegistry registry;
|
||||
return registry;
|
||||
}
|
||||
|
||||
RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
|
||||
res.setContent("pong", "text/plain");
|
||||
}};
|
||||
|
||||
} // namespace
|
||||
|
||||
void registerHandler(const std::string& name, HandlerFunc f) {
|
||||
return getHandlerRegistry().registerHandler(name, f);
|
||||
}
|
||||
|
||||
HandlerFunc getHandler(const std::string& name) {
|
||||
return getHandlerRegistry().getHandler(name);
|
||||
}
|
||||
|
||||
std::vector<std::string> getHandlerNames() {
|
||||
return getHandlerRegistry().getHandlerNames();
|
||||
}
|
||||
|
||||
} // namespace control_plane
|
||||
} // namespace c10d
|
67
torch/csrc/distributed/c10d/control_plane/Handlers.hpp
Normal file
67
torch/csrc/distributed/c10d/control_plane/Handlers.hpp
Normal file
@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace control_plane {
|
||||
|
||||
// Request represents a request to the handler. This conceptually maps to an
|
||||
// HTTP request but could be called via other transports.
|
||||
class TORCH_API Request {
|
||||
public:
|
||||
virtual ~Request() = default;
|
||||
|
||||
virtual const std::string& body() = 0;
|
||||
};
|
||||
|
||||
// Response represents a response to the handler. This conceptually maps to an
|
||||
// HTTP response but could be called via other transports.
|
||||
class TORCH_API Response {
|
||||
public:
|
||||
virtual ~Response() = default;
|
||||
|
||||
// Set the response body to the provided string.
|
||||
// TODO: add support for chunked responses
|
||||
virtual void setContent(
|
||||
std::string&& content,
|
||||
const std::string& content_type) = 0;
|
||||
|
||||
// Set the response status code.
|
||||
// These should match standard HTTP status codes.
|
||||
virtual void setStatus(int status) = 0;
|
||||
};
|
||||
|
||||
using HandlerFunc = std::function<void(const Request&, Response&)>;
|
||||
|
||||
// Registers a handler. The name needs to be unique and can be called by using
|
||||
// getHandler directly or via WorkerServer for remote requests.
|
||||
// These handlers are called from a background C++ thread concurrently with the
|
||||
// main thread. These handlers need to be thread safe and not cause issues
|
||||
// during Python training.
|
||||
TORCH_API void registerHandler(const std::string& name, HandlerFunc f);
|
||||
|
||||
// Fetches a handler by name.
|
||||
TORCH_API HandlerFunc getHandler(const std::string& name);
|
||||
|
||||
TORCH_API std::vector<std::string> getHandlerNames();
|
||||
|
||||
// Registers a handler statically.
|
||||
// See registerHandler for more details.
|
||||
class TORCH_API RegisterHandler {
|
||||
public:
|
||||
RegisterHandler(const std::string& name, HandlerFunc f) {
|
||||
registerHandler(name, f);
|
||||
}
|
||||
|
||||
// disable move, copy
|
||||
RegisterHandler(const RegisterHandler&) = delete;
|
||||
RegisterHandler(RegisterHandler&&) = delete;
|
||||
RegisterHandler& operator=(const RegisterHandler&) = delete;
|
||||
RegisterHandler& operator=(RegisterHandler&&) = delete;
|
||||
};
|
||||
|
||||
} // namespace control_plane
|
||||
} // namespace c10d
|
178
torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp
Normal file
178
torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp
Normal file
@ -0,0 +1,178 @@
|
||||
#include <filesystem>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <ATen/core/interned_strings.h>
|
||||
#include <caffe2/utils/threadpool/WorkersPool.h>
|
||||
#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
|
||||
#include <torch/csrc/distributed/c10d/logging.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace control_plane {
|
||||
|
||||
namespace {
|
||||
class RequestImpl : public Request {
|
||||
public:
|
||||
RequestImpl(const httplib::Request& req) : req_(req) {}
|
||||
|
||||
const std::string& body() override {
|
||||
return req_.body;
|
||||
}
|
||||
|
||||
private:
|
||||
const httplib::Request& req_;
|
||||
};
|
||||
|
||||
class ResponseImpl : public Response {
|
||||
public:
|
||||
ResponseImpl(httplib::Response& res) : res_(res) {}
|
||||
|
||||
void setStatus(int status) override {
|
||||
res_.status = status;
|
||||
}
|
||||
|
||||
void setContent(std::string&& content, const std::string& content_type)
|
||||
override {
|
||||
res_.set_content(std::move(content), content_type);
|
||||
}
|
||||
|
||||
private:
|
||||
httplib::Response& res_;
|
||||
};
|
||||
|
||||
std::string jsonStrEscape(const std::string& str) {
|
||||
std::ostringstream ostream;
|
||||
for (char ch : str) {
|
||||
if (ch == '"') {
|
||||
ostream << "\\\"";
|
||||
} else if (ch == '\\') {
|
||||
ostream << "\\\\";
|
||||
} else if (ch == '\b') {
|
||||
ostream << "\\b";
|
||||
} else if (ch == '\f') {
|
||||
ostream << "\\f";
|
||||
} else if (ch == '\n') {
|
||||
ostream << "\\n";
|
||||
} else if (ch == '\r') {
|
||||
ostream << "\\r";
|
||||
} else if (ch == '\t') {
|
||||
ostream << "\\t";
|
||||
} else if ('\x00' <= ch && ch <= '\x1f') {
|
||||
ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0')
|
||||
<< static_cast<int>(ch);
|
||||
} else {
|
||||
ostream << ch;
|
||||
}
|
||||
}
|
||||
return ostream.str();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
WorkerServer::WorkerServer(const std::string& socketFile) {
|
||||
// using unix sockets
|
||||
server_.set_address_family(AF_UNIX);
|
||||
|
||||
// adjust keep alives as it stops the server from shutting down quickly
|
||||
server_.set_keep_alive_timeout(1); // second, default is 5
|
||||
server_.set_keep_alive_max_count(
|
||||
30); // wait max 30 seconds before closing socket
|
||||
|
||||
server_.Get("/", [](const httplib::Request& req, httplib::Response& res) {
|
||||
res.set_content(
|
||||
R"BODY(<h1>torch.distributed.WorkerServer</h1>
|
||||
<a href="/handler/">Handler names</a>
|
||||
)BODY",
|
||||
"text/html");
|
||||
});
|
||||
server_.Get(
|
||||
"/handler/", [](const httplib::Request& req, httplib::Response& res) {
|
||||
std::ostringstream body;
|
||||
body << "[";
|
||||
bool first = true;
|
||||
for (const auto& name : getHandlerNames()) {
|
||||
if (!first) {
|
||||
body << ",";
|
||||
}
|
||||
first = false;
|
||||
|
||||
body << "\"" << jsonStrEscape(name) << "\"";
|
||||
}
|
||||
body << "]";
|
||||
|
||||
res.set_content(body.str(), "application/json");
|
||||
});
|
||||
server_.Post(
|
||||
"/handler/:handler",
|
||||
[](const httplib::Request& req, httplib::Response& res) {
|
||||
auto handler_name = req.path_params.at("handler");
|
||||
HandlerFunc handler;
|
||||
try {
|
||||
handler = getHandler(handler_name);
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 404;
|
||||
res.set_content(
|
||||
fmt::format("Handler {} not found: {}", handler_name, e.what()),
|
||||
"text/plain");
|
||||
return;
|
||||
}
|
||||
RequestImpl torchReq{req};
|
||||
ResponseImpl torchRes{res};
|
||||
|
||||
try {
|
||||
handler(torchReq, torchRes);
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
res.set_content(
|
||||
fmt::format("Handler {} failed: {}", handler_name, e.what()),
|
||||
"text/plain");
|
||||
return;
|
||||
} catch (...) {
|
||||
res.status = 500;
|
||||
res.set_content(
|
||||
fmt::format(
|
||||
"Handler {} failed with unknown exception", handler_name),
|
||||
"text/plain");
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
if (std::filesystem::exists(socketFile)) {
|
||||
throw std::runtime_error(fmt::format("{} already exists", socketFile));
|
||||
}
|
||||
|
||||
C10D_WARNING("Server listening to {}", socketFile);
|
||||
if (!server_.bind_to_port(socketFile, 80)) {
|
||||
throw std::runtime_error(fmt::format("Error binding to {}", socketFile));
|
||||
}
|
||||
|
||||
serverThread_ = std::thread([this]() {
|
||||
try {
|
||||
if (!server_.listen_after_bind()) {
|
||||
throw std::runtime_error("failed to listen");
|
||||
}
|
||||
} catch (std::exception& e) {
|
||||
C10D_ERROR("Error while running server: {}", e.what());
|
||||
throw;
|
||||
}
|
||||
C10D_WARNING("Server exited");
|
||||
});
|
||||
}
|
||||
|
||||
void WorkerServer::shutdown() {
|
||||
C10D_WARNING("Server shutting down");
|
||||
server_.stop();
|
||||
serverThread_.join();
|
||||
}
|
||||
|
||||
WorkerServer::~WorkerServer() {
|
||||
if (serverThread_.joinable()) {
|
||||
C10D_WARNING("WorkerServer destructor called without shutdown");
|
||||
shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace control_plane
|
||||
} // namespace c10d
|
28
torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp
Normal file
28
torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp
Normal file
@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <httplib.h>
|
||||
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
|
||||
|
||||
namespace c10d {
|
||||
namespace control_plane {
|
||||
|
||||
class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
|
||||
public:
|
||||
WorkerServer(const std::string& socketFile);
|
||||
~WorkerServer();
|
||||
|
||||
void shutdown();
|
||||
|
||||
private:
|
||||
httplib::Server server_;
|
||||
std::thread serverThread_;
|
||||
};
|
||||
|
||||
} // namespace control_plane
|
||||
} // namespace c10d
|
@ -8,6 +8,7 @@
|
||||
#include <torch/csrc/distributed/c10d/Utils.hpp>
|
||||
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
|
||||
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
|
||||
#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
|
||||
#include <vector>
|
||||
#ifndef _WIN32
|
||||
#include <torch/csrc/distributed/c10d/HashStore.hpp>
|
||||
@ -3164,6 +3165,17 @@ such as `dist.all_reduce(tensor, async_op=True)`.
|
||||
return py::bytes(::c10d::dump_nccl_trace());
|
||||
});
|
||||
#endif
|
||||
|
||||
intrusive_ptr_class_<::c10d::control_plane::WorkerServer>(
|
||||
module, "_WorkerServer", R"(
|
||||
)")
|
||||
.def(
|
||||
py::init([](const std::string& socketPath) {
|
||||
return c10::make_intrusive<::c10d::control_plane::WorkerServer>(
|
||||
socketPath);
|
||||
}),
|
||||
py::arg("socket_path"))
|
||||
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
||||
|
51
torch/distributed/elastic/control_plane.py
Normal file
51
torch/distributed/elastic/control_plane.py
Normal file
@ -0,0 +1,51 @@
|
||||
import os
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from typing import Generator
|
||||
|
||||
from torch.distributed.elastic.multiprocessing.errors import record
|
||||
|
||||
__all__ = [
|
||||
"worker_main",
|
||||
]
|
||||
|
||||
TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _worker_server(socket_path: str) -> Generator[None, None, None]:
|
||||
from torch._C._distributed_c10d import _WorkerServer
|
||||
|
||||
server = _WorkerServer(socket_path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
@contextmanager
|
||||
@record
|
||||
def worker_main() -> Generator[None, None, None]:
|
||||
"""
|
||||
This is a context manager that wraps your main entry function. This combines
|
||||
the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that
|
||||
exposes handlers via a unix socket specified by
|
||||
``Torch_WORKER_SERVER_SOCKET``.
|
||||
|
||||
Example
|
||||
|
||||
::
|
||||
|
||||
@worker_main()
|
||||
def main():
|
||||
pass
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
with ExitStack() as stack:
|
||||
socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
|
||||
if socket_path is not None:
|
||||
stack.enter_context(_worker_server(socket_path))
|
||||
|
||||
yield
|
Reference in New Issue
Block a user