Compare commits

...

1 Commit

Author SHA1 Message Date
a0b03005aa Reapply "distributed/debug: add an HTTP server for debugging running jobs (#167395)"
This reverts commit 1c1638297e06444e60942719c35ddfb7a9133cea.
2025-11-17 13:34:16 -08:00
14 changed files with 629 additions and 7 deletions

View File

@@ -402,3 +402,6 @@ scikit-build==0.18.1
pyre-extensions==0.0.32
tabulate==0.9.0
#Description: These packages are needed to build FBGEMM and torchrec on PyTorch CI
Jinja2==3.1.6
#Description: required for torch.distributed.debug

View File

@@ -987,6 +987,24 @@ In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `T
collective desynchronization checks will work for all applications that use `c10d` collective calls backed by process groups created with the
{func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs.
### torch.distributed.debug HTTP Server
The `torch.distributed.debug` module provides an HTTP server that can be used to debug distributed applications. The server is
started by calling {func}`torch.distributed.debug.start_debug_server`, which
allows users to collect data across all workers at runtime.
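
A minimal usage sketch (assuming the standard `torchrun` environment variables such as `MASTER_ADDR`, `MASTER_PORT`, `RANK`, and `WORLD_SIZE` are set; the port is illustrative):

```python
from torch.distributed.debug import start_debug_server, stop_debug_server

start_debug_server(port=25999)  # frontend is served from rank 0
# ... run the job; browse to http://<rank0-host>:25999 to view stack traces,
# FlightRecorder dumps, and torch profiler traces across all ranks ...
stop_debug_server()
```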
```{eval-rst}
.. automodule:: torch.distributed.debug
:members:
:undoc-members:
:show-inheritance:
:special-members: __init__
:member-order: bysource
```
## Logging
In addition to explicit debugging support via {func}`torch.distributed.monitored_barrier` and `TORCH_DISTRIBUTED_DEBUG`, the underlying C++ library of `torch.distributed` also outputs log

View File

@@ -0,0 +1,56 @@
# Owner(s): ["oncall: distributed"]
import os

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import torch
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase


session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)


class TestDebug(TestCase):
    def test_basics(self) -> None:
        store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(store.port)
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"

        port = 25999

        def fetch(path: str) -> str:
            resp = session.get(f"http://localhost:{port}{path}")
            resp.raise_for_status()
            return resp.text

        start_debug_server(port=port)

        self.assertIn("torch profiler", fetch("/"))
        self.assertIn("View 0", fetch("/profile?duration=0.01"))
        self.assertIn("test_basics", fetch("/stacks"))
        self.assertIn("pg_status", fetch("/fr_trace"))

        if torch.cuda.is_available():
            self.assertIn("pg_status", fetch("/fr_trace_nccl"))

        # test errors
        resp = session.get(f"http://localhost:{port}/blah")
        self.assertEqual(resp.status_code, 404)
        self.assertIn("Handler not found: /blah", resp.text)

        stop_debug_server()


if __name__ == "__main__":
    run_tests()

View File

@@ -100,7 +100,9 @@ class Logger:
def _set_static_graph(self) -> None: ...
class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
port: int
def __init__(self, host_or_file: str, port: int = ...) -> None: ...
def shutdown(self) -> None: ...
def get_debug_level(): ...
@@ -206,6 +208,7 @@ class Store:
desired_value: str,
) -> bytes: ...
def delete_key(self, key: str) -> bool: ...
def multi_get(self, keys: list[str]) -> list[bytes]: ...
def num_keys(self) -> int: ...
def set_timeout(self, timeout: timedelta): ...
@overload
@@ -871,3 +874,15 @@ class ProcessGroupXCCL(Backend):
def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...
class _Request:
def body(self) -> bytes: ...
def get_param(self, key: str) -> str: ...
class _Response:
def set_content(self, content: str | bytes, content_type: str) -> None: ...
def set_status(self, status: int) -> None: ...
def _register_handler(
name: str, handler: Callable[[_Request, _Response], None]
) -> None: ...
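
A hedged sketch of how these private bindings fit together (the API is internal and may change; the handler name `hello` is illustrative):

```python
from torch._C._distributed_c10d import _register_handler, _Request, _Response


def _hello(req: _Request, resp: _Response) -> None:
    # get_param returns "" when the parameter is absent (see Request::getParam)
    who = req.get_param("name") or "world"
    resp.set_content(f"hello {who}", "text/plain")
    resp.set_status(200)


_register_handler("hello", _hello)
```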

View File

@@ -60,6 +60,7 @@ class _ExperimentalConfig:
verbose: bool = ...,
performance_events: list[str] = ...,
enable_cuda_sync_events: bool = ...,
profile_all_threads: bool = ...,
) -> None: ...
class ProfilerConfig:

View File

@@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
res.setStatus(200);
}};
RegisterHandler frTracehandler(
"fr_trace_json",
[](const Request&, Response& res) {
auto trace = ::c10d::dump_fr_trace_json(true, true);
res.setContent(std::move(trace), "application/json");
res.setStatus(200);
});
} // namespace
void registerHandler(const std::string& name, HandlerFunc f) {

View File

@@ -18,6 +18,14 @@ class TORCH_API Request {
virtual const std::string& body() const = 0;
virtual const std::multimap<std::string, std::string>& params() const = 0;
std::string getParam(const std::string& key) const {
auto it = params().find(key);
if (it != params().end()) {
return it->second;
}
return "";
}
};
// Response represents a response to the handler. This conceptually maps to an

View File

@@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
TORCH_CHECK(
server_.bind_to_port(hostOrFile, 80),
fmt::format("Error binding to {}", hostOrFile));
} else if (port == 0) {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
port_ = server_.bind_to_any_port(hostOrFile);
TORCH_CHECK(
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
TORCH_CHECK(
server_.bind_to_port(hostOrFile, port),
fmt::format("Error binding to {}:{}", hostOrFile, port));
port_ = port;
}
serverThread_ = std::thread([this]() {

View File

@@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
void shutdown();
int port() {
return port_;
}
private:
httplib::Server server_;
std::thread serverThread_;
int port_;
};
} // namespace c10d::control_plane
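
On the Python side this class is exposed as `_WorkerServer`; a minimal standalone sketch, assuming the bindings added below (passing port 0 exercises the `bind_to_any_port` branch above):

```python
from torch._C._distributed_c10d import _WorkerServer

server = _WorkerServer("::", 0)  # port 0 binds an ephemeral port
print(server.port)               # the port actually bound
server.shutdown()
```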

View File

@@ -46,6 +46,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@@ -4203,7 +4204,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
.def_property_readonly(
"port", &::c10d::control_plane::WorkerServer::port);
module.def(
"_get_handler",
@@ -4219,6 +4222,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Returns the handler with the specified name.
)");
module.def(
"_register_handler",
[](const std::string& name, const py::function& handler) {
::c10d::control_plane::registerHandler(
name,
[handler](
const ::c10d::control_plane::Request& req,
::c10d::control_plane::Response& res) {
py::gil_scoped_acquire acquire;
handler(std::ref(req), std::ref(res));
});
},
py::arg("name"),
py::arg("handler"),
R"(
Registers a handler by name.
)");
module.def(
"_get_handler_names",
&::c10d::control_plane::getHandlerNames,
@@ -4236,12 +4258,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
// Default constructor.
.def(py::init<>())
.def("body", &::c10d::control_plane::Request::body)
.def("params", &::c10d::control_plane::Request::params);
.def("get_param", &::c10d::control_plane::Request::getParam);
py::class_<
::c10d::control_plane::Response,
std::shared_ptr<::c10d::control_plane::Response>,
PythonResponse>(
py::class_<::c10d::control_plane::Response, PythonResponse>(
module,
"_Response",
R"(

View File

@@ -0,0 +1,82 @@
import logging
import multiprocessing
import socket

# import for registration side effect
import torch.distributed.debug._handlers  # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client


__all__ = [
    "start_debug_server",
    "stop_debug_server",
]

logger: logging.Logger = logging.getLogger(__name__)

_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None


def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
    """
    Start the debug server stack on all workers. The frontend debug server is
    started only on rank 0, while the per rank worker servers are started on
    all ranks.

    This server provides an HTTP frontend that allows for debugging slow and
    deadlocked distributed jobs across all ranks simultaneously. It collects
    data such as stack traces, FlightRecorder events, and performance profiles.

    WARNING: This is intended to be used only in trusted network environments.
    The debug server is not designed to be secure and should not be exposed to
    the public internet. See SECURITY.md for more details.

    WARNING: This is an experimental feature and may change at any time.

    Args:
        port (int): The port to start the frontend debug server on.
        worker_port (int): The port to start the worker server on. Defaults to 0,
            which causes the worker server to bind to an ephemeral port.
    """
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _WORKER_SERVER is None, "debug server already started"
    assert _DEBUG_SERVER_PROC is None, "debug server already started"

    logger.info("Starting debug server on port %d", port)

    store = tcpstore_client()

    _WORKER_SERVER = _WorkerServer("::", worker_port)

    RANK = get_rank()
    store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")

    from torch.distributed.debug._frontend import main

    if RANK == 0:
        _DEBUG_SERVER_PROC = multiprocessing.Process(
            target=main, args=(port,), daemon=True
        )
        _DEBUG_SERVER_PROC.start()


def stop_debug_server() -> None:
    """
    Shut down the worker server and stop the frontend debug server process.
    """
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _DEBUG_SERVER_PROC is not None
    assert _WORKER_SERVER is not None

    logger.info("Stopping debug server")

    _DEBUG_SERVER_PROC.terminate()
    _WORKER_SERVER.shutdown()
    _DEBUG_SERVER_PROC.join()

    _WORKER_SERVER = None
    _DEBUG_SERVER_PROC = None
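
Once the stack is up, the frontend's routes can also be fetched programmatically rather than through a browser; a hedged sketch (assumes this host is rank 0 and the default port from above):

```python
import requests

# Paths match the frontend's routes: /, /stacks, /fr_trace, /fr_trace_nccl, /profile.
stacks_html = requests.get("http://localhost:25999/stacks").text
profile_html = requests.get("http://localhost:25999/profile?duration=0.5").text
```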

View File

@@ -0,0 +1,353 @@
import json
import logging
import socket
import threading
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import parse_qs, urlparse

import requests
from jinja2 import DictLoader, Environment

from torch.distributed.debug._store import get_world_size, tcpstore_client


logger: logging.Logger = logging.getLogger(__name__)


def fetch_all(
    endpoint: str, args: str = ""
) -> tuple[list[str], Iterator[requests.Response]]:
    store = tcpstore_client()
    keys = [f"rank{r}" for r in range(get_world_size())]
    addrs = store.multi_get(keys)
    addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]

    with ThreadPoolExecutor(max_workers=10) as executor:
        resps = executor.map(requests.post, addrs)

    return addrs, resps


def format_json(blob: str):
    parsed = json.loads(blob)
    return json.dumps(parsed, indent=2)
templates = {
"base.html": """
<!doctype html>
<head>
<title>{% block title %}{% endblock %} - PyTorch Distributed</title>
<link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">
<style>
body {
margin: 0;
font-family:
-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
"Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
"Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
font-size: 1rem;
font-weight: 400;
line-height: 1.5;
color: #212529;
text-align: left;
background-color: #fff;
}
h1, h2, h3, h4, h5, h6, .h1, .h2, .h3, .h4, .h5, .h6 {
margin-bottom: .5rem;
font-weight: 500;
line-height: 1.2;
}
nav {
background-color: rgba(0, 0, 0, 0.17);
padding: 10px;
display: flex;
align-items: center;
padding: 16px;
justify-content: flex-start;
}
nav h1 {
display: inline-block;
margin: 0;
}
nav a {
margin: 0 8px;
}
section {
max-width: 1280px;
padding: 16px;
margin: 0 auto;
}
pre {
white-space: pre-wrap;
max-width: 100%;
}
</style>
</head>
<nav>
<h1>Torch Distributed Debug Server</h1>
<a href="/">Home</a> <!--@lint-ignore-->
<a href="/stacks">Python Stack Traces</a> <!--@lint-ignore-->
<a href="/fr_trace">FlightRecorder</a> <!--@lint-ignore-->
<a href="/fr_trace_nccl">FlightRecorder NCCL</a> <!--@lint-ignore-->
<a href="/profile">torch profiler</a> <!--@lint-ignore-->
</nav>
<section class="content">
{% block header %}{% endblock %}
{% block content %}{% endblock %}
</section>
""",
"index.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}Index{% endblock %}</h1>
{% endblock %}
{% block content %}
Hi
{% endblock %}
""",
"raw_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{title}}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ resp.text }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"json_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ format_json(resp.text) }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"profile.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}torch.profiler{% endblock %}</h1>
{% endblock %}
{% block content %}
<form action="/profile" method="get">
<label for="duration">Duration (seconds):</label>
<input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
<input type="submit" value="Submit">
</form>
<script>
function stringToArrayBuffer(str) {
const encoder = new TextEncoder();
return encoder.encode(str).buffer;
}
async function openPerfetto(data) {
const ui = window.open('https://ui.perfetto.dev/#!/');
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
// Perfetto readiness handshake: PING until we receive PONG
await new Promise((resolve, reject) => {
const onMsg = (e) => {
if (e.source === ui && e.data === 'PONG') {
window.removeEventListener('message', onMsg);
clearInterval(pinger);
resolve();
}
};
window.addEventListener('message', onMsg);
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
}).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });
ui.postMessage({
perfetto: {
buffer: stringToArrayBuffer(JSON.stringify(data)),
title: "torch profiler",
fileName: "trace.json",
}
}, '*');
}
</script>
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<script>
function run{{ i }}() {
var data = {{ resp.text | safe }};
openPerfetto(data);
}
</script>
<button onclick="run{{ i }}()">View {{ i }}</button>
{% endif %}
{% endfor %}
{% endblock %}
""",
}
class _IPv6HTTPServer(ThreadingHTTPServer):
    address_family: socket.AddressFamily = socket.AF_INET6  # pyre-ignore
    request_queue_size: int = 1024


class HTTPRequestHandler(BaseHTTPRequestHandler):
    frontend: "FrontendServer"

    def do_GET(self):
        self.frontend._handle_request(self)

    def get_path(self) -> str:
        return urlparse(self.path).path

    def get_query(self) -> dict[str, list[str]]:
        return parse_qs(urlparse(self.path).query)

    def get_query_arg(
        self, name: str, default: object = None, type: type = str
    ) -> object:
        query = self.get_query()
        if name not in query:
            return default
        return type(query[name][0])
class FrontendServer:
    def __init__(self, port: int):
        # Setup templates
        loader = DictLoader(templates)
        self._jinja_env = Environment(loader=loader, enable_async=True)
        self._jinja_env.globals.update(
            zip=zip,
            format_json=format_json,
            enumerate=enumerate,
        )

        # Create routes
        self._routes = {
            "/": self._handle_index,
            "/stacks": self._handle_stacks,
            "/fr_trace": self._handle_fr_trace,
            "/fr_trace_nccl": self._handle_fr_trace_nccl,
            "/profile": self._handle_profiler,
        }

        # Create HTTP server
        RequestHandlerClass = type(
            "HTTPRequestHandler",
            (HTTPRequestHandler,),
            {"frontend": self},
        )
        server_address = ("", port)
        self._server = _IPv6HTTPServer(server_address, RequestHandlerClass)

        self._thread = threading.Thread(
            target=self._serve,
            args=(),
            daemon=True,
        )
        self._thread.start()

    def _serve(self) -> None:
        try:
            self._server.serve_forever()
        except Exception:
            logger.exception("got exception in debug server frontend")

    def join(self) -> None:
        self._thread.join()

    def _handle_request(self, req: HTTPRequestHandler) -> None:
        path = req.get_path()
        if path not in self._routes:
            req.send_error(404, f"Handler not found: {path}")
            return

        handler = self._routes[path]
        try:
            resp = handler(req)
        except Exception as e:
            logger.exception(
                "Exception in debug server frontend when handling %s",
                path,
            )
            req.send_error(500, str(e))
            return

        req.send_response(200)
        req.send_header("Content-type", "text/html")
        req.end_headers()
        req.wfile.write(resp)

    def _render_template(self, template: str, **kwargs: object) -> bytes:
        return self._jinja_env.get_template(template).render(**kwargs).encode()

    def _handle_index(self, req: HTTPRequestHandler) -> bytes:
        return self._render_template("index.html")

    def _handle_stacks(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("dump_traceback")
        return self._render_template(
            "raw_resp.html", title="Stacks", addrs=addrs, resps=resps
        )

    def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("fr_trace_json")
        return self._render_template(
            "json_resp.html",
            title="FlightRecorder",
            addrs=addrs,
            resps=resps,
        )

    def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")
        return self._render_template(
            "json_resp.html",
            title="FlightRecorder NCCL",
            addrs=addrs,
            resps=resps,
        )

    def _handle_profiler(self, req: HTTPRequestHandler) -> bytes:
        duration = req.get_query_arg("duration", default=1.0, type=float)
        addrs, resps = fetch_all("torch_profile", f"duration={duration}")
        return self._render_template("profile.html", addrs=addrs, resps=resps)


def main(port: int) -> None:
    server = FrontendServer(port=port)
    logger.info("Frontend server started on port %d", server._server.server_port)
    server.join()

View File

@@ -0,0 +1,22 @@
import tempfile
import time

from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile


def _torch_profile(req: _Request, resp: _Response) -> None:
    experimental_config = _ExperimentalConfig(
        profile_all_threads=True,
    )
    duration = float(req.get_param("duration"))
    with profile(record_shapes=True, experimental_config=experimental_config) as prof:
        time.sleep(duration)

    with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
        prof.export_chrome_trace(f.name)
        with open(f.name, "rb") as trace:
            resp.set_content(trace.read(), "application/json")
    resp.set_status(200)


_register_handler("torch_profile", _torch_profile)

View File

@@ -0,0 +1,24 @@
import os

import torch.distributed as dist


def get_rank() -> int:
    return int(os.environ["RANK"])


def get_world_size() -> int:
    return int(os.environ["WORLD_SIZE"])


def tcpstore_client() -> dist.Store:
    MASTER_ADDR = os.environ["MASTER_ADDR"]
    MASTER_PORT = int(os.environ["MASTER_PORT"])

    store = dist.TCPStore(
        host_name=MASTER_ADDR,
        port=MASTER_PORT,
        is_master=False,
    )
    store = dist.PrefixStore("debug_server", store)
    return store
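
These helpers read everything from the environment; a minimal sketch of the variables `tcpstore_client` expects (values are illustrative and normally set by `torchrun`; the TCPStore master must already be running, as in the test above):

```python
import os

os.environ["MASTER_ADDR"] = "localhost"  # host of the running TCPStore master
os.environ["MASTER_PORT"] = "29500"      # its port
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

from torch.distributed.debug._store import tcpstore_client

store = tcpstore_client()  # client connection; keys live under the "debug_server" prefix
```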