Reapply "distributed debug handlers (#126601)" (#127805)

This reverts commit 7646825c3eb687030c4f873b01312be0eed80174.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127805
Approved by: https://github.com/PaliC
Author: Tristan Rice
Date: 2024-06-04 19:44:30 +00:00
Committed by: PyTorch MergeBot
Parent: e76b28c765
Commit: 597922ba21
18 changed files with 548 additions and 0 deletions

View File

@@ -769,6 +769,7 @@ cc_library(
        ":caffe2",
        ":torch_headers",
        "@kineto",
        "@cpp-httplib",
    ] + if_cuda([
        "@cuda//:nvToolsExt",
        "@cutlass",

View File

@@ -168,6 +168,12 @@ new_local_repository(
    path = "third_party/opentelemetry-cpp",
)

new_local_repository(
    name = "cpp-httplib",
    build_file = "//third_party:cpp-httplib.BUILD",
    path = "third_party/cpp-httplib",
)

new_local_repository(
    name = "tensorpipe",
    build_file = "//third_party:tensorpipe.BUILD",

View File

@@ -515,6 +515,8 @@ libtorch_distributed_base_sources = [
    "torch/csrc/distributed/c10d/sequence_num.cpp",
    "torch/csrc/distributed/c10d/socket.cpp",
    "torch/csrc/distributed/c10d/Work.cpp",
    "torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
    "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
]

# These files are only supported on Linux (and other POSIX platforms), not on Windows.

View File

@@ -1179,6 +1179,9 @@ if(USE_KINETO)
      ${TORCH_ROOT}/third_party/kineto/libkineto/src)
endif()

target_include_directories(torch_cpu PRIVATE
  ${TORCH_ROOT}/third_party/cpp-httplib)

install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
        DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
        FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp")

View File

@@ -1681,3 +1681,7 @@ endif()

# Include google/FlatBuffers
include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)

# Include cpp-httplib
add_library(httplib INTERFACE IMPORTED)
target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/cpp-httplib)

View File

@@ -29,6 +29,7 @@ Documentation
    elastic/metrics
    elastic/events
    elastic/subprocess_handler
    elastic/control_plane

.. toctree::
   :maxdepth: 1

View File

@@ -0,0 +1,10 @@
Control Plane
=============

.. automodule:: torch.distributed.elastic.control_plane
.. currentmodule:: torch.distributed.elastic.control_plane

This module contains optional helpers that add extra debug and control handlers
to your application.

.. autofunction:: torch.distributed.elastic.control_plane.worker_main
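A minimal usage sketch (the socket path below is an illustrative assumption, not
a default):

.. code-block:: python

    import os

    from torch.distributed.elastic.control_plane import worker_main

    # Hypothetical path; any writable location works.
    os.environ["TORCH_WORKER_SERVER_SOCKET"] = "/tmp/worker_server.sock"

    @worker_main()
    def main():
        ...  # training loop; debug handlers are served while main() runs

    if __name__ == "__main__":
        main()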

View File

@@ -0,0 +1,86 @@
#!/usr/bin/env python3
# Owner(s): ["oncall: distributed"]
import json
import os
import pickle
import socket
import tempfile
from contextlib import contextmanager
from typing import Generator

from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool

from torch.distributed.elastic.control_plane import (
    TORCH_WORKER_SERVER_SOCKET,
    worker_main,
)
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase


class UnixHTTPConnection(HTTPConnection):
    def __init__(self, socket_path: str) -> None:
        super().__init__("localhost")

        self.socket_path = socket_path

    def connect(self) -> None:
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.sock.connect(self.socket_path)


class UnixHTTPConnectionPool(HTTPConnectionPool):
    def __init__(self, socket_path: str) -> None:
        super().__init__("localhost")

        self.socket_path = socket_path

    def _new_conn(self):
        return UnixHTTPConnection(self.socket_path)


@contextmanager
def local_worker_server() -> Generator[UnixHTTPConnectionPool, None, None]:
    with tempfile.TemporaryDirectory() as tmpdir:
        socket_path = os.path.join(tmpdir, "socket.sock")
        os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path

        with worker_main():
            pool = UnixHTTPConnectionPool(socket_path)
            yield pool


class WorkerServerTest(TestCase):
    def test_worker_server(self) -> None:
        with local_worker_server() as pool:
            resp = pool.request("GET", "/")
            self.assertEqual(resp.status, 200)
            self.assertEqual(
                resp.data,
                b"""<h1>torch.distributed.WorkerServer</h1>
<a href="/handler/">Handler names</a>
""",
            )

            resp = pool.request("POST", "/handler/ping")
            self.assertEqual(resp.status, 200)
            self.assertEqual(resp.data, b"pong")

            resp = pool.request("GET", "/handler/")
            self.assertEqual(resp.status, 200)
            self.assertIn("ping", json.loads(resp.data))

            resp = pool.request("POST", "/handler/nonexistent")
            self.assertEqual(resp.status, 404)
            self.assertIn(b"Handler nonexistent not found:", resp.data)

    @requires_cuda
    def test_dump_nccl_trace_pickle(self) -> None:
        with local_worker_server() as pool:
            resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
            self.assertEqual(resp.status, 200)
            out = pickle.loads(resp.data)  # raises if the trace payload is malformed


if __name__ == "__main__":
    run_tests()

third_party/cpp-httplib.BUILD vendored Normal file
View File

@@ -0,0 +1,10 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
cc_library(
name = "cpp-httplib",
hdrs = ["httplib.h"],
includes = [
"/",
],
visibility = ["//visibility:public"],
)

View File

@@ -68,6 +68,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES
    ${TORCH_ROOT}/third_party/onnx
    ${TORCH_ROOT}/third_party/flatbuffers/include
    ${TORCH_ROOT}/third_party/kineto/libkineto/include
    ${TORCH_ROOT}/third_party/cpp-httplib
    ${TORCH_SRC_DIR}/csrc
    ${TORCH_SRC_DIR}/csrc/api/include

@@ -80,6 +81,7 @@ set(TORCH_PYTHON_LINK_LIBRARIES
    Python::Module
    pybind::pybind11
    opentelemetry::api
    httplib
    shm
    fmt::fmt-header-only
    ATEN_CPU_FILES_GEN_LIB)

View File

@@ -94,6 +94,10 @@ class Logger:
    def _set_uneven_input_join(self) -> None: ...
    def _set_static_graph(self) -> None: ...

class _WorkerServer:
    def __init__(self, socket_path: str) -> None: ...
    def shutdown(self) -> None: ...

def get_debug_level(): ...
def set_debug_level(): ...
def set_debug_level_from_env(): ...
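For illustration, the binding can be driven directly from Python; a minimal
sketch (the socket path is a hypothetical example, not a default):

import os
import tempfile

from torch._C._distributed_c10d import _WorkerServer

# Hypothetical location for the unix socket.
socket_path = os.path.join(tempfile.mkdtemp(), "worker.sock")

server = _WorkerServer(socket_path)  # begins serving on a background thread
try:
    ...  # handlers are reachable on socket_path while the server runs
finally:
    server.shutdown()  # stops the server and joins its thread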

View File

@@ -28,6 +28,7 @@
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/TraceUtils.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <torch/torch.h>

@@ -369,6 +370,13 @@ std::string dump_nccl_trace() {
}
#endif

// TODO(c-p-i-o): add a JSON endpoint.
control_plane::RegisterHandler dumpHandler{
    "dump_nccl_trace_pickle",
    [](const control_plane::Request&, control_plane::Response& res) {
      res.setContent(dump_nccl_trace(), "application/octet-stream");
    }};

std::optional<std::function<void(std::function<void(const std::string&)>)>>&
get_cpp_trace_dumper() {
  static std::optional<
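As a sketch of how a client can exercise this handler, assuming a worker server
is listening on a unix socket (the path below is hypothetical), mirroring the
urllib3 helpers from the test file above:

import pickle
import socket

from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool

class UnixHTTPConnection(HTTPConnection):
    def __init__(self, socket_path: str) -> None:
        super().__init__("localhost")
        self.socket_path = socket_path

    def connect(self) -> None:
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.sock.connect(self.socket_path)

class UnixHTTPConnectionPool(HTTPConnectionPool):
    def __init__(self, socket_path: str) -> None:
        super().__init__("localhost")
        self.socket_path = socket_path

    def _new_conn(self):
        return UnixHTTPConnection(self.socket_path)

pool = UnixHTTPConnectionPool("/tmp/worker_server.sock")  # hypothetical path
resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
assert resp.status == 200
trace = pickle.loads(resp.data)  # the pickled NCCL trace entries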

View File

@@ -0,0 +1,75 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

#include <fmt/format.h>

#include <mutex>
#include <shared_mutex>
#include <stdexcept>
#include <unordered_map>
#include <vector>

namespace c10d {
namespace control_plane {

namespace {

class HandlerRegistry {
 public:
  void registerHandler(const std::string& name, HandlerFunc f) {
    std::unique_lock<std::shared_mutex> lock(handlersMutex_);

    if (handlers_.find(name) != handlers_.end()) {
      throw std::runtime_error(
          fmt::format("Handler {} already registered", name));
    }

    handlers_[name] = f;
  }

  HandlerFunc getHandler(const std::string& name) {
    std::shared_lock<std::shared_mutex> lock(handlersMutex_);

    auto it = handlers_.find(name);
    if (it == handlers_.end()) {
      throw std::runtime_error(fmt::format("Failed to find handler {}", name));
    }
    // Return via the iterator; operator[] can mutate the map and is not safe
    // under a shared (read) lock.
    return it->second;
  }

  std::vector<std::string> getHandlerNames() {
    std::shared_lock<std::shared_mutex> lock(handlersMutex_);

    std::vector<std::string> names;
    for (const auto& [name, _] : handlers_) {
      names.push_back(name);
    }
    return names;
  }

 private:
  std::shared_mutex handlersMutex_{};
  std::unordered_map<std::string, HandlerFunc> handlers_{};
};

HandlerRegistry& getHandlerRegistry() {
  static HandlerRegistry registry;
  return registry;
}

RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
                              res.setContent("pong", "text/plain");
                            }};

} // namespace

void registerHandler(const std::string& name, HandlerFunc f) {
  return getHandlerRegistry().registerHandler(name, f);
}

HandlerFunc getHandler(const std::string& name) {
  return getHandlerRegistry().getHandler(name);
}

std::vector<std::string> getHandlerNames() {
  return getHandlerRegistry().getHandlerNames();
}

} // namespace control_plane
} // namespace c10d

View File

@@ -0,0 +1,67 @@
#pragma once

#include <functional>
#include <string>
#include <vector>

#include <c10/macros/Export.h>

namespace c10d {
namespace control_plane {

// Request represents a request to the handler. This conceptually maps to an
// HTTP request but may arrive via other transports.
class TORCH_API Request {
 public:
  virtual ~Request() = default;

  virtual const std::string& body() = 0;
};

// Response represents a response from the handler. This conceptually maps to
// an HTTP response but may be delivered via other transports.
class TORCH_API Response {
 public:
  virtual ~Response() = default;

  // Set the response body to the provided string.
  // TODO: add support for chunked responses
  virtual void setContent(
      std::string&& content,
      const std::string& content_type) = 0;

  // Set the response status code.
  // These should match standard HTTP status codes.
  virtual void setStatus(int status) = 0;
};

using HandlerFunc = std::function<void(const Request&, Response&)>;

// Registers a handler. The name must be unique; the handler can be invoked
// directly via getHandler or remotely via WorkerServer.
// Handlers are called from a background C++ thread concurrently with the main
// thread, so they must be thread safe and must not interfere with the running
// Python training loop.
TORCH_API void registerHandler(const std::string& name, HandlerFunc f);

// Fetches a handler by name.
TORCH_API HandlerFunc getHandler(const std::string& name);

TORCH_API std::vector<std::string> getHandlerNames();

// Registers a handler statically at program startup.
// See registerHandler for more details.
class TORCH_API RegisterHandler {
 public:
  RegisterHandler(const std::string& name, HandlerFunc f) {
    registerHandler(name, f);
  }

  // disable move, copy
  RegisterHandler(const RegisterHandler&) = delete;
  RegisterHandler(RegisterHandler&&) = delete;
  RegisterHandler& operator=(const RegisterHandler&) = delete;
  RegisterHandler& operator=(RegisterHandler&&) = delete;
};

} // namespace control_plane
} // namespace c10d
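Handlers registered through this API become reachable over the worker server's
HTTP surface. A minimal Python sketch, reusing the UnixHTTPConnectionPool
helper from the test file above (the socket path is hypothetical):

import json

# UnixHTTPConnectionPool is the urllib3 helper defined in the test file above.
pool = UnixHTTPConnectionPool("/tmp/worker_server.sock")  # hypothetical path

# GET /handler/ lists registered handler names as a JSON array.
names = json.loads(pool.request("GET", "/handler/").data)
assert "ping" in names

# POST /handler/<name> invokes a handler; the built-in ping handler replies "pong".
assert pool.request("POST", "/handler/ping").data == b"pong"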

View File

@@ -0,0 +1,178 @@
#include <filesystem>
#include <mutex>
#include <shared_mutex>
#include <sstream>
#include <tuple>
#include <unordered_map>

#include <ATen/core/interned_strings.h>
#include <caffe2/utils/threadpool/WorkersPool.h>
#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
#include <torch/csrc/distributed/c10d/logging.h>

namespace c10d {
namespace control_plane {

namespace {
class RequestImpl : public Request {
 public:
  RequestImpl(const httplib::Request& req) : req_(req) {}

  const std::string& body() override {
    return req_.body;
  }

 private:
  const httplib::Request& req_;
};

class ResponseImpl : public Response {
 public:
  ResponseImpl(httplib::Response& res) : res_(res) {}

  void setStatus(int status) override {
    res_.status = status;
  }

  void setContent(std::string&& content, const std::string& content_type)
      override {
    res_.set_content(std::move(content), content_type);
  }

 private:
  httplib::Response& res_;
};

std::string jsonStrEscape(const std::string& str) {
  std::ostringstream ostream;
  for (char ch : str) {
    if (ch == '"') {
      ostream << "\\\"";
    } else if (ch == '\\') {
      ostream << "\\\\";
    } else if (ch == '\b') {
      ostream << "\\b";
    } else if (ch == '\f') {
      ostream << "\\f";
    } else if (ch == '\n') {
      ostream << "\\n";
    } else if (ch == '\r') {
      ostream << "\\r";
    } else if (ch == '\t') {
      ostream << "\\t";
    } else if ('\x00' <= ch && ch <= '\x1f') {
      ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0')
              << static_cast<int>(ch);
    } else {
      ostream << ch;
    }
  }
  return ostream.str();
}
} // namespace

WorkerServer::WorkerServer(const std::string& socketFile) {
  // using unix sockets
  server_.set_address_family(AF_UNIX);

  // shorten keep-alives, since long-lived connections keep the server from
  // shutting down quickly
  server_.set_keep_alive_timeout(1); // seconds, default is 5
  server_.set_keep_alive_max_count(
      30); // wait at most ~30 seconds before closing the socket

  server_.Get("/", [](const httplib::Request& req, httplib::Response& res) {
    res.set_content(
        R"BODY(<h1>torch.distributed.WorkerServer</h1>
<a href="/handler/">Handler names</a>
)BODY",
        "text/html");
  });
  server_.Get(
      "/handler/", [](const httplib::Request& req, httplib::Response& res) {
        std::ostringstream body;
        body << "[";
        bool first = true;
        for (const auto& name : getHandlerNames()) {
          if (!first) {
            body << ",";
          }
          first = false;

          body << "\"" << jsonStrEscape(name) << "\"";
        }
        body << "]";

        res.set_content(body.str(), "application/json");
      });
  server_.Post(
      "/handler/:handler",
      [](const httplib::Request& req, httplib::Response& res) {
        auto handler_name = req.path_params.at("handler");
        HandlerFunc handler;
        try {
          handler = getHandler(handler_name);
        } catch (const std::exception& e) {
          res.status = 404;
          res.set_content(
              fmt::format("Handler {} not found: {}", handler_name, e.what()),
              "text/plain");
          return;
        }
        RequestImpl torchReq{req};
        ResponseImpl torchRes{res};

        try {
          handler(torchReq, torchRes);
        } catch (const std::exception& e) {
          res.status = 500;
          res.set_content(
              fmt::format("Handler {} failed: {}", handler_name, e.what()),
              "text/plain");
          return;
        } catch (...) {
          res.status = 500;
          res.set_content(
              fmt::format(
                  "Handler {} failed with unknown exception", handler_name),
              "text/plain");
          return;
        }
      });

  if (std::filesystem::exists(socketFile)) {
    throw std::runtime_error(fmt::format("{} already exists", socketFile));
  }

  C10D_WARNING("Server listening to {}", socketFile);
  // the port argument is unused for AF_UNIX; the socket path is the "host"
  if (!server_.bind_to_port(socketFile, 80)) {
    throw std::runtime_error(fmt::format("Error binding to {}", socketFile));
  }

  serverThread_ = std::thread([this]() {
    try {
      if (!server_.listen_after_bind()) {
        throw std::runtime_error("failed to listen");
      }
    } catch (std::exception& e) {
      C10D_ERROR("Error while running server: {}", e.what());
      throw;
    }
    C10D_WARNING("Server exited");
  });
}

void WorkerServer::shutdown() {
  C10D_WARNING("Server shutting down");
  server_.stop();
  serverThread_.join();
}

WorkerServer::~WorkerServer() {
  if (serverThread_.joinable()) {
    C10D_WARNING("WorkerServer destructor called without shutdown");
    shutdown();
  }
}

} // namespace control_plane
} // namespace c10d

View File

@@ -0,0 +1,28 @@
#pragma once

#include <string>
#include <thread>
#include <unordered_map>

#include <httplib.h>

#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

namespace c10d {
namespace control_plane {

class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
 public:
  WorkerServer(const std::string& socketFile);
  ~WorkerServer();

  void shutdown();

 private:
  httplib::Server server_;
  std::thread serverThread_;
};

} // namespace control_plane
} // namespace c10d

View File

@@ -8,6 +8,7 @@
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
#include <vector>

#ifndef _WIN32
#include <torch/csrc/distributed/c10d/HashStore.hpp>

@@ -3164,6 +3165,17 @@ such as `dist.all_reduce(tensor, async_op=True)`.
        return py::bytes(::c10d::dump_nccl_trace());
      });
#endif

  intrusive_ptr_class_<::c10d::control_plane::WorkerServer>(
      module, "_WorkerServer", R"(
)")
      .def(
          py::init([](const std::string& socketPath) {
            return c10::make_intrusive<::c10d::control_plane::WorkerServer>(
                socketPath);
          }),
          py::arg("socket_path"))
      .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);

  Py_RETURN_TRUE;
}

View File

@@ -0,0 +1,51 @@
import os
from contextlib import contextmanager, ExitStack
from typing import Generator

from torch.distributed.elastic.multiprocessing.errors import record

__all__ = [
    "worker_main",
]

TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"


@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
    from torch._C._distributed_c10d import _WorkerServer

    server = _WorkerServer(socket_path)
    try:
        yield
    finally:
        server.shutdown()


@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
    """
    This is a context manager that wraps your main entry function. It combines
    the existing ``errors.record`` logic with a new ``_WorkerServer`` that
    exposes handlers via a unix socket specified by
    ``TORCH_WORKER_SERVER_SOCKET``.

    Example::

        @worker_main()
        def main():
            pass

        if __name__ == "__main__":
            main()
    """
    with ExitStack() as stack:
        socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
        if socket_path is not None:
            stack.enter_context(_worker_server(socket_path))

        yield
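The docstring above shows the decorator form; a minimal sketch of the
equivalent context-manager form (the socket path is an illustrative
assumption):

import os

from torch.distributed.elastic.control_plane import (
    TORCH_WORKER_SERVER_SOCKET,
    worker_main,
)

os.environ[TORCH_WORKER_SERVER_SOCKET] = "/tmp/worker_server.sock"  # hypothetical

def train() -> None:
    ...  # training loop

with worker_main():
    train()  # debug handlers are served on the socket for the duration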