Compare commits


1 commit

6 changed files with 348 additions and 58 deletions

Changed file: Python type stubs for torch._C._distributed_c10d

@@ -100,7 +100,9 @@ class Logger:
def _set_static_graph(self) -> None: ...
class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
port: int
def __init__(self, host_or_file: str, port: int = ...) -> None: ...
def shutdown(self) -> None: ...
def get_debug_level(): ...
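
The stub change mirrors the new constructor in WorkerServer.cpp below: _WorkerServer now takes a host name or file path plus an optional port instead of a bare socket path, and exposes the bound port. A minimal hedged sketch of the new surface (host string and port choice are illustrative):

from torch._C._distributed_c10d import _WorkerServer

# "::" binds all interfaces; port 0 asks the server to pick a free port,
# which can then be read back from the new `port` attribute.
server = _WorkerServer("::", 0)
print("worker server listening on port", server.port)
server.shutdown()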

Changed file: c10d control-plane handler registration (C++)

@@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
res.setStatus(200);
}};
RegisterHandler frTracehandler(
"fr_trace_json",
[](const Request&, Response& res) {
auto trace = ::c10d::dump_fr_trace_json(true, true);
res.setContent(std::move(trace), "application/json");
res.setStatus(200);
});
} // namespace
void registerHandler(const std::string& name, HandlerFunc f) {
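
The new fr_trace_json handler returns the Flight Recorder dump as JSON through the worker's control-plane HTTP server; torch/distributed/debug.py (added later in this diff) reaches registered handlers with a POST to /handler/<name>. A hedged sketch of such a request, assuming a worker server is already listening on localhost:25999 (an arbitrary example address):

import requests

# In practice the address is discovered from the TCPStore keys written by
# torch/distributed/debug.py rather than hard-coded.
resp = requests.post("http://localhost:25999/handler/fr_trace_json")
print(resp.status_code, resp.headers.get("Content-Type"))  # expect 200, application/json
print(resp.json())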

Changed file: c10d control-plane WorkerServer (C++)

@@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
TORCH_CHECK(
server_.bind_to_port(hostOrFile, 80),
fmt::format("Error binding to {}", hostOrFile));
} else if (port == 0) {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
port_ = server_.bind_to_any_port(hostOrFile);
TORCH_CHECK(
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
TORCH_CHECK(
server_.bind_to_port(hostOrFile, port),
fmt::format("Error binding to {}:{}", hostOrFile, port));
port_ = port;
}
serverThread_ = std::thread([this]() {
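
The first TORCH_CHECK shown belongs to the default-port branch (port = -1, per the pybind default later in this diff), where hostOrFile appears to be treated as a Unix socket path; port 0 takes bind_to_any_port and records the OS-assigned port in port_, and any other value binds that exact port. A rough Python illustration of the two TCP modes (ports are arbitrary examples):

from torch._C._distributed_c10d import _WorkerServer

ephemeral = _WorkerServer("localhost", 0)   # bind_to_any_port branch
assert ephemeral.port > 0                   # OS-assigned port is recorded in port_

fixed = _WorkerServer("localhost", 29876)   # bind_to_port branch; arbitrary free port
assert fixed.port == 29876

ephemeral.shutdown()
fixed.shutdown()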

Changed file: c10d control-plane WorkerServer header (C++)

@@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
void shutdown();
int port() {
return port_;
}
private:
httplib::Server server_;
std::thread serverThread_;
int port_;
};
} // namespace c10d::control_plane

Changed file: c10d Python bindings (C++)

@@ -46,6 +46,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@@ -381,8 +382,9 @@ void _register_comm_hook(
::c10d::Reducer& reducer,
py::object state,
py::object comm_hook) {
reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
std::move(state), std::move(comm_hook)));
reducer.register_comm_hook(
std::make_unique<::c10d::PythonCommHook>(
std::move(state), std::move(comm_hook)));
}
// Called from DDP's Python API to create a c10d C++ comm hook.
@@ -882,37 +884,39 @@ This class does not support ``__members__`` property.)");
[](const ::c10d::ReduceOp& self, const py::dict& memo) {
return ::c10d::ReduceOp(self);
})
.def(py::pickle(
[](const ::c10d::ReduceOp& r) {
// __getstate__
if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return py::make_tuple(r.op_, py::none());
}
TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
const auto* preMulSupplement =
reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
r.supplement_.get());
if (!preMulSupplement->tensor_factor.defined()) {
return py::make_tuple(r.op_, preMulSupplement->double_factor);
} else {
return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
}
},
[](const py::tuple& t) {
// __setstate__
TORCH_CHECK(t.size() == 2, "Invalid state");
const auto op =
static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>());
if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return ::c10d::ReduceOp(op);
}
const auto preMulSupplement_factor = t[1];
if (py::isinstance<py::float_>(preMulSupplement_factor)) {
return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
} else {
return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
}
}));
.def(
py::pickle(
[](const ::c10d::ReduceOp& r) {
// __getstate__
if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return py::make_tuple(r.op_, py::none());
}
TORCH_CHECK(
r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
const auto* preMulSupplement =
reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
r.supplement_.get());
if (!preMulSupplement->tensor_factor.defined()) {
return py::make_tuple(r.op_, preMulSupplement->double_factor);
} else {
return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
}
},
[](const py::tuple& t) {
// __setstate__
TORCH_CHECK(t.size() == 2, "Invalid state");
const auto op = static_cast<::c10d::ReduceOp::RedOpType>(
t[0].cast<uint8_t>());
if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return ::c10d::ReduceOp(op);
}
const auto preMulSupplement_factor = t[1];
if (py::isinstance<py::float_>(preMulSupplement_factor)) {
return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
} else {
return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
}
}));
py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType")
.value("SUM", ::c10d::ReduceOp::RedOpType::SUM)
@@ -3579,10 +3583,11 @@ Example::
[](std::optional<bool> includeCollectives,
std::optional<bool> includeStackTraces,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_xccl_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_xccl_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("includeStackTraces") = std::optional<bool>(),
@@ -4112,8 +4117,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
"_dump_nccl_trace_json",
[](std::optional<bool> includeCollectives,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_nccl_trace_json(
includeCollectives.value_or(true), onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_nccl_trace_json(
includeCollectives.value_or(true), onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("onlyActive") = std::optional<bool>(),
@@ -4130,10 +4136,11 @@ such as `dist.all_reduce(tensor, async_op=True)`.
[](std::optional<bool> includeCollectives,
std::optional<bool> includeStackTraces,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_nccl_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_nccl_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("includeStackTraces") = std::optional<bool>(),
@@ -4157,8 +4164,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
"_dump_fr_trace_json",
[](std::optional<bool> includeCollectives,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_fr_trace_json(
includeCollectives.value_or(true), onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_fr_trace_json(
includeCollectives.value_or(true), onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("onlyActive") = std::optional<bool>(),
@@ -4175,10 +4183,11 @@ such as `dist.all_reduce(tensor, async_op=True)`.
[](std::optional<bool> includeCollectives,
std::optional<bool> includeStackTraces,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_fr_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_fr_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("includeStackTraces") = std::optional<bool>(),
@@ -4203,7 +4212,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
.def_property_readonly(
"port", &::c10d::control_plane::WorkerServer::port);
module.def(
"_get_handler",
@@ -4219,6 +4230,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Returns the handler with the specified name.
)");
module.def(
"_register_handler",
[](const std::string& name, py::function handler) {
::c10d::control_plane::registerHandler(
name,
[handler](
const ::c10d::control_plane::Request& req,
::c10d::control_plane::Response& res) {
py::gil_scoped_acquire acquire;
handler(std::ref(req), std::ref(res));
});
},
py::arg("name"),
py::arg("handler"),
R"(
Registers a handler by name.
)");
module.def(
"_get_handler_names",
&::c10d::control_plane::getHandlerNames,
@@ -4238,10 +4268,7 @@ such as `dist.all_reduce(tensor, async_op=True)`.
.def("body", &::c10d::control_plane::Request::body)
.def("params", &::c10d::control_plane::Request::params);
py::class_<
::c10d::control_plane::Response,
std::shared_ptr<::c10d::control_plane::Response>,
PythonResponse>(
py::class_<::c10d::control_plane::Response, PythonResponse>(
module,
"_Response",
R"(
@@ -4267,9 +4294,10 @@ such as `dist.all_reduce(tensor, async_op=True)`.
} // namespace
// c10d methods on torch._C
static PyMethodDef methods[] = { // NOLINT
{"_c10d_init", c10d_init, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
static PyMethodDef methods[] =
{ // NOLINT
{"_c10d_init", c10d_init, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
// NOLINTNEXTLINE(misc-use-internal-linkage)
PyMethodDef* python_functions() {
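
The new _register_handler binding wraps a Python callable into a control-plane HandlerFunc and acquires the GIL before calling it, so handlers can be written in pure Python; debug.py below registers a torch_profile handler this way. A smaller hedged example (handler name and body are illustrative):

from torch._C._distributed_c10d import _register_handler

def _echo(req, resp):
    # Request exposes body()/params(); Response exposes set_content()/set_status(),
    # matching the bindings above.
    resp.set_content(b"pong", "text/plain")
    resp.set_status(200)

_register_handler("echo", _echo)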

New file: torch/distributed/debug.py (239 lines)

@@ -0,0 +1,239 @@
import os
import socket
import multiprocessing
import requests
from concurrent.futures import ThreadPoolExecutor
import json
import time
import tempfile
from flask import Flask
import torch.distributed as dist
from torch.profiler import (
profile,
ProfilerActivity,
record_function,
_ExperimentalConfig,
)
from torch._C._distributed_c10d import _WorkerServer, _register_handler
def _torch_profile(req, resp):
experimental_config = _ExperimentalConfig(
profile_all_threads=True,
)
with profile(record_shapes=True, experimental_config=experimental_config) as prof:
time.sleep(2)
with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
prof.export_chrome_trace(f.name)
resp.set_content(open(f.name, "rb").read(), "application/json")
resp.set_status(200)
_register_handler("torch_profile", _torch_profile)
MASTER_ADDR = os.environ["MASTER_ADDR"]
MASTER_PORT = int(os.environ["MASTER_PORT"])
RANK = int(os.environ["RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
def _tcpstore_client() -> dist.Store:
store = dist.TCPStore(
host_name=MASTER_ADDR,
port=MASTER_PORT,
is_master=False,
)
store = dist.PrefixStore("debug_server", store)
return store
def fetch_all(endpoint: str) -> tuple[list[str], list[requests.Response]]:
store = _tcpstore_client()
keys = [f"rank{r}" for r in range(WORLD_SIZE)]
addrs = store.multi_get(keys)
addrs = [f"{addr.decode()}/handler/{endpoint}" for addr in addrs]
with ThreadPoolExecutor(max_workers=10) as executor:
        resps = list(executor.map(requests.post, addrs))
return addrs, resps
app = Flask(__name__)
def nav():
return """
<style>
body {
font-family: sans-serif;
}
pre {
white-space: pre-wrap;
max-width: 100%;
}
</style>
<h1>Torch Distributed Debug Server</h1>
<ul>
<li><a href="/">Home</a></li>
<li><a href="/stacks">Python Stack Traces</a></li>
<li><a href="/fr_trace">FlightRecorder</a></li>
<li><a href="/fr_trace_nccl">FlightRecorder NCCL</a></li>
<li><a href="/profile">torch profiler</a></li>
</ul>
"""
@app.route("/")
def index():
return nav()
@app.route("/stacks")
def stacks():
addrs, resps = fetch_all("dump_traceback")
def generate():
yield nav()
yield "<h2>Stacks</h2>"
for i, addr, resp in zip(range(len(addrs)), addrs, resps):
yield f"<h3>Rank {i}: {addr}</h3>"
if resp.status_code != 200:
yield f"<p>Failed to fetch: status={resp.status_code}</p>"
stack = resp.text
yield f"<pre>{stack}</pre>"
return generate()
def format_json(blob: str):
parsed = json.loads(blob)
return json.dumps(parsed, indent=2)
@app.route("/fr_trace")
def fr_trace():
addrs, resps = fetch_all("fr_trace_json")
def generate():
yield nav()
yield "<h2>FlightRecorder</h2>"
for i, addr, resp in zip(range(len(addrs)), addrs, resps):
yield f"<h3>Rank {i}: {addr}</h3>"
if resp.status_code != 200:
yield f"<p>Failed to fetch: status={resp.status_code}</p>"
stack = format_json(resp.text)
yield f"<pre>{stack}</pre>"
return generate()
@app.route("/fr_trace_nccl")
def fr_trace_nccl():
addrs, resps = fetch_all("dump_nccl_trace_json?onlyactive=true")
def generate():
yield nav()
yield "<h2>FlightRecorder NCCL</h2>"
for i, addr, resp in zip(range(len(addrs)), addrs, resps):
yield f"<h3>Rank {i}: {addr}</h3>"
if resp.status_code != 200:
yield f"<p>Failed to fetch: status={resp.status_code}</p>"
stack = format_json(resp.text)
yield f"<pre>{stack}</pre>"
return generate()
@app.route("/profile")
def profiler():
addrs, resps = fetch_all("torch_profile")
def generate():
yield nav()
yield """
<h2>torch profile</h2>
<script>
function stringToArrayBuffer(str) {
const encoder = new TextEncoder();
return encoder.encode(str).buffer;
}
async function openPerfetto(data) {
const ui = window.open('https://ui.perfetto.dev/#!/');
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
// Perfetto readiness handshake: PING until we receive PONG
await new Promise((resolve, reject) => {
const onMsg = (e) => {
if (e.source === ui && e.data === 'PONG') {
window.removeEventListener('message', onMsg);
clearInterval(pinger);
resolve();
}
};
window.addEventListener('message', onMsg);
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
}).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });
ui.postMessage({
perfetto: {
buffer: stringToArrayBuffer(JSON.stringify(data)),
title: "torch profiler",
fileName: "trace.json",
}
}, '*');
}
</script>
"""
for i, addr, resp in zip(range(len(addrs)), addrs, resps):
yield f"<h3>Rank {i}: {addr}</h3>"
if resp.status_code != 200:
yield f"<p>Failed to fetch: status={resp.status_code}</p>"
stack = resp.text
yield f"""
<script>
function run{i}() {{
var data = {stack};
openPerfetto(data);
}}
</script>
<button onclick="run{i}()">View {i}</button>
"""
return generate()
def _interactive_server() -> None:
app.run(host="::", port=25999)
def enable_debug_server() -> None:
global _worker_server, _p
store = _tcpstore_client()
_worker_server = _WorkerServer("::", 0)
store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_worker_server.port}")
if RANK == 0:
_p = multiprocessing.Process(
target=_interactive_server,
)
_p.start()
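
A hedged sketch of how the new module could be used from a torchrun-launched training script: it reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment at import time, each rank publishes its worker-server address to the TCPStore in enable_debug_server(), and rank 0 additionally starts the Flask UI on port 25999 (backend and call order here are illustrative):

# e.g. launched with: torchrun --nproc-per-node=8 train.py
import torch.distributed as dist

dist.init_process_group("nccl")  # env:// rendezvous provides the TCPStore

from torch.distributed.debug import enable_debug_server
enable_debug_server()

# ... training loop ...
# Then browse to http://<rank 0 host>:25999/ for stacks, FlightRecorder dumps, and profiles.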