From 38bb262afdb2a5974be71404d493f533bd32e821 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 7 Nov 2025 17:11:21 -0800 Subject: [PATCH] distributed/debug: add an HTTP server for debugging running jobs --- .ci/docker/requirements-ci.txt | 3 + test/distributed/test_debug.py | 50 ++++ torch/_C/_distributed_c10d.pyi | 17 +- torch/_C/_profiler.pyi | 1 + .../c10d/control_plane/Handlers.cpp | 10 + .../c10d/control_plane/Handlers.hpp | 8 + .../c10d/control_plane/WorkerServer.cpp | 6 + .../c10d/control_plane/WorkerServer.hpp | 5 + torch/csrc/distributed/c10d/init.cpp | 31 +- torch/distributed/debug/__init__.py | 52 ++++ torch/distributed/debug/_flask.py | 265 ++++++++++++++++++ torch/distributed/debug/_handlers.py | 22 ++ torch/distributed/debug/_store.py | 24 ++ 13 files changed, 487 insertions(+), 7 deletions(-) create mode 100644 test/distributed/test_debug.py create mode 100644 torch/distributed/debug/__init__.py create mode 100644 torch/distributed/debug/_flask.py create mode 100644 torch/distributed/debug/_handlers.py create mode 100644 torch/distributed/debug/_store.py diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index f3636071714f..08363d2420e2 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -402,3 +402,6 @@ scikit-build==0.18.1 pyre-extensions==0.0.32 tabulate==0.9.0 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI + +Flask==3.1.2 +#Description: required for torch.distributed.debug diff --git a/test/distributed/test_debug.py b/test/distributed/test_debug.py new file mode 100644 index 000000000000..b108e99324ef --- /dev/null +++ b/test/distributed/test_debug.py @@ -0,0 +1,50 @@ +# Owner(s): ["oncall: distributed"] + +import os + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +import torch.distributed as dist +from torch.distributed.debug import start_debug_server, stop_debug_server +from torch.testing._internal.common_utils import run_tests, TestCase + + +session = requests.Session() +retry_strategy = Retry(total=5, backoff_factor=0.5) +adapter = HTTPAdapter(max_retries=retry_strategy) +session.mount("http://", adapter) +session.mount("https://", adapter) + + +class TestDebug(TestCase): + def test_basics(self) -> None: + store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(store.port) + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + port = 25999 + + def fetch(path: str) -> str: + resp = session.get(f"http://localhost:{port}{path}") + resp.raise_for_status() + return resp.text + + print("starting!") + + start_debug_server(port=port) + + self.assertIn("torch profiler", fetch("/")) + self.assertIn("View 0", fetch("/profile?duration=0.01")) + self.assertIn("test_basics", fetch("/stacks")) + self.assertIn("pg_status", fetch("/fr_trace")) + self.assertIn("pg_status", fetch("/fr_trace_nccl")) + + stop_debug_server() + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index b659be9ee119..a80efc696e17 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -100,7 +100,9 @@ class Logger: def _set_static_graph(self) -> None: ... class _WorkerServer: - def __init__(self, socket_path: str) -> None: ... + port: int + + def __init__(self, host_or_file: str, port: int = ...) -> None: ... def shutdown(self) -> None: ... def get_debug_level(): ... @@ -206,6 +208,7 @@ class Store: desired_value: str, ) -> bytes: ... def delete_key(self, key: str) -> bool: ... + def multi_get(self, keys: list[str]) -> list[bytes]: ... def num_keys(self) -> int: ... def set_timeout(self, timeout: timedelta): ... @overload @@ -871,3 +874,15 @@ class ProcessGroupXCCL(Backend): def _set_process_group(pg: ProcessGroup) -> None: ... def _current_process_group() -> ProcessGroup: ... + +class _Request: + def body(self) -> bytes: ... + def get_param(self, str) -> str: ... + +class _Response: + def set_content(self, content: str | bytes, content_type: str) -> None: ... + def set_status(self, status: int) -> None: ... + +def _register_handler( + name: str, handler: Callable[[_Request, _Response], None] +) -> None: ... diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index d60d89a6a479..de12af50c185 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -60,6 +60,7 @@ class _ExperimentalConfig: verbose: bool = ..., performance_events: list[str] = ..., enable_cuda_sync_events: bool = ..., + profile_all_threads: bool = ..., ) -> None: ... class ProfilerConfig: diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index 10274d053b99..fe8f831a23bb 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) { res.setStatus(200); }}; +RegisterHandler frTracehandler( + "fr_trace_json", + [](const Request&, Response& res) { + auto trace = ::c10d::dump_fr_trace_json(true, true); + res.setContent(std::move(trace), "application/json"); + res.setStatus(200); + }); + } // namespace void registerHandler(const std::string& name, HandlerFunc f) { diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp index 70333a3a4844..58ae9368ea21 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -18,6 +18,14 @@ class TORCH_API Request { virtual const std::string& body() const = 0; virtual const std::multimap& params() const = 0; + + std::string getParam(const std::string& key) const { + auto it = params().find(key); + if (it != params().end()) { + return it->second; + } + return ""; + } }; // Response represents a response to the handler. This conceptually maps to an diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index 2f77bb119a95..f9a3034b0dd9 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { TORCH_CHECK( server_.bind_to_port(hostOrFile, 80), fmt::format("Error binding to {}", hostOrFile)); + } else if (port == 0) { + C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); + port_ = server_.bind_to_any_port(hostOrFile); + TORCH_CHECK( + port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port)); } else { C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); TORCH_CHECK( server_.bind_to_port(hostOrFile, port), fmt::format("Error binding to {}:{}", hostOrFile, port)); + port_ = port; } serverThread_ = std::thread([this]() { diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp index 41c1356fc01f..20d05b7509e9 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp @@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target { void shutdown(); + int port() { + return port_; + } + private: httplib::Server server_; std::thread serverThread_; + int port_; }; } // namespace c10d::control_plane diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 91bb3469e3e8..6f38cd9cd2c6 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -46,6 +46,7 @@ #include #include +#include #include #include #include @@ -4203,7 +4204,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. }), py::arg("host_or_file"), py::arg("port") = -1) - .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); + .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown) + .def_property_readonly( + "port", &::c10d::control_plane::WorkerServer::port); module.def( "_get_handler", @@ -4219,6 +4222,25 @@ such as `dist.all_reduce(tensor, async_op=True)`. Returns the handler with the specified name. )"); + module.def( + "_register_handler", + [](const std::string& name, const py::function& handler) { + ::c10d::control_plane::registerHandler( + name, + [handler]( + const ::c10d::control_plane::Request& req, + ::c10d::control_plane::Response& res) { + py::gil_scoped_acquire acquire; + handler(std::ref(req), std::ref(res)); + }); + }, + + py::arg("name"), + py::arg("handler"), + R"( + Registers a handler by name. + )"); + module.def( "_get_handler_names", &::c10d::control_plane::getHandlerNames, @@ -4236,12 +4258,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. // Default constructor. .def(py::init<>()) .def("body", &::c10d::control_plane::Request::body) - .def("params", &::c10d::control_plane::Request::params); + .def("get_param", &::c10d::control_plane::Request::getParam); - py::class_< - ::c10d::control_plane::Response, - std::shared_ptr<::c10d::control_plane::Response>, - PythonResponse>( + py::class_<::c10d::control_plane::Response, PythonResponse>( module, "_Response", R"( diff --git a/torch/distributed/debug/__init__.py b/torch/distributed/debug/__init__.py new file mode 100644 index 000000000000..a99c57374e39 --- /dev/null +++ b/torch/distributed/debug/__init__.py @@ -0,0 +1,52 @@ +import multiprocessing +import socket + +# import for registration side effect +import torch.distributed.debug._handlers # noqa: F401 +from torch._C._distributed_c10d import _WorkerServer +from torch.distributed.debug._store import get_rank, tcpstore_client + + +__all__ = [ + "start_debug_server", + "stop_debug_server", +] + +_WORKER_SERVER: _WorkerServer | None = None +_DEBUG_SERVER_PROC: multiprocessing.Process | None = None + + +def start_debug_server(port: int = 25999, worker_port: int = 0) -> None: + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _WORKER_SERVER is None, "debug server already started" + assert _DEBUG_SERVER_PROC is None, "debug server already started" + + store = tcpstore_client() + + _WORKER_SERVER = _WorkerServer("::", worker_port) + + RANK = get_rank() + store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}") + + from torch.distributed.debug._flask import main + + if RANK == 0: + _DEBUG_SERVER_PROC = multiprocessing.Process( + target=main, args=(port,), daemon=True + ) + _DEBUG_SERVER_PROC.start() + + +def stop_debug_server() -> None: + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _DEBUG_SERVER_PROC is not None + assert _WORKER_SERVER is not None + + _DEBUG_SERVER_PROC.terminate() + _WORKER_SERVER.shutdown() + _DEBUG_SERVER_PROC.join() + + _WORKER_SERVER = None + _DEBUG_SERVER_PROC = None diff --git a/torch/distributed/debug/_flask.py b/torch/distributed/debug/_flask.py new file mode 100644 index 000000000000..3fb4c7d895a7 --- /dev/null +++ b/torch/distributed/debug/_flask.py @@ -0,0 +1,265 @@ +import json +from concurrent.futures import ThreadPoolExecutor +from typing import Iterator + +import requests +from flask import Flask, render_template, request +from jinja2 import DictLoader + +from torch.distributed.debug._store import get_world_size, tcpstore_client + + +def fetch_all( + endpoint: str, args: str = "" +) -> tuple[list[str], Iterator[requests.Response]]: + store = tcpstore_client() + keys = [f"rank{r}" for r in range(get_world_size())] + addrs = store.multi_get(keys) + addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs] + + with ThreadPoolExecutor(max_workers=10) as executor: + resps = executor.map(requests.post, addrs) + + return addrs, resps + + +def format_json(blob: str): + parsed = json.loads(blob) + return json.dumps(parsed, indent=2) + + +templates = { + "base.html": """ + + + {% block title %}{% endblock %} - PyTorch Distributed + + + + + + + +
+ {% block header %}{% endblock %} + {% for message in get_flashed_messages() %} +
{{ message }}
+ {% endfor %} + {% block content %}{% endblock %} +
+ """, + "index.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}Index{% endblock %}

+{% endblock %} +{% block content %} +Hi +{% endblock %} + """, + "raw_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{title}}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ resp.text }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "json_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{ title }}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ format_json(resp.text) }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "profile.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}torch.profiler{% endblock %}

+{% endblock %} + +{% block content %} +
+ + + +
+ + + + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} + + + + {% endif %} + {% endfor %} +{% endblock %} + """, +} + +app = Flask(__name__) +app.jinja_loader = DictLoader(templates) +app.jinja_env.globals.update( + zip=zip, + format_json=format_json, + enumerate=enumerate, +) + + +@app.route("/") +def _index_handler(): + return render_template("index.html") + + +@app.route("/stacks") +def _stacks_handler(): + addrs, resps = fetch_all("dump_traceback") + return render_template("raw_resp.html", title="Stacks", addrs=addrs, resps=resps) + + +@app.route("/fr_trace") +def _fr_trace_handler(): + addrs, resps = fetch_all("fr_trace_json") + + return render_template( + "json_resp.html", + title="FlightRecorder", + addrs=addrs, + resps=resps, + ) + + +@app.route("/fr_trace_nccl") +def _fr_trace_nccl_handler(): + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + + return render_template( + "json_resp.html", + title="FlightRecorder NCCL", + addrs=addrs, + resps=resps, + ) + + +@app.route("/profile") +def _profiler_handler(): + duration = request.args.get("duration", default=1.0, type=float) + + addrs, resps = fetch_all("torch_profile", f"duration={duration}") + + return render_template("profile.html", addrs=addrs, resps=resps) + + +def main(port: int) -> None: + app.run(host="::", port=port) diff --git a/torch/distributed/debug/_handlers.py b/torch/distributed/debug/_handlers.py new file mode 100644 index 000000000000..ba951b7bda07 --- /dev/null +++ b/torch/distributed/debug/_handlers.py @@ -0,0 +1,22 @@ +import tempfile +import time + +from torch._C._distributed_c10d import _register_handler, _Request, _Response +from torch.profiler import _ExperimentalConfig, profile + + +def _torch_profile(req: _Request, resp: _Response) -> None: + experimental_config = _ExperimentalConfig( + profile_all_threads=True, + ) + duration = float(req.get_param("duration")) + with profile(record_shapes=True, experimental_config=experimental_config) as prof: + time.sleep(duration) + + with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f: + prof.export_chrome_trace(f.name) + resp.set_content(open(f.name, "rb").read(), "application/json") + resp.set_status(200) + + +_register_handler("torch_profile", _torch_profile) diff --git a/torch/distributed/debug/_store.py b/torch/distributed/debug/_store.py new file mode 100644 index 000000000000..70c6cd0f3dde --- /dev/null +++ b/torch/distributed/debug/_store.py @@ -0,0 +1,24 @@ +import os + +import torch.distributed as dist + + +def get_rank() -> int: + return int(os.environ["RANK"]) + + +def get_world_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def tcpstore_client() -> dist.Store: + MASTER_ADDR = os.environ["MASTER_ADDR"] + MASTER_PORT = int(os.environ["MASTER_PORT"]) + + store = dist.TCPStore( + host_name=MASTER_ADDR, + port=MASTER_PORT, + is_master=False, + ) + store = dist.PrefixStore("debug_server", store) + return store