distributed/debug: add an HTTP server for debugging running jobs
This change adds a torch.distributed.debug module. Each rank starts a lightweight HTTP worker server (_WorkerServer) and publishes its address in the job's TCPStore; rank 0 additionally runs a Flask frontend that aggregates Python stack traces, FlightRecorder dumps, and torch.profiler traces from every rank.
@@ -402,3 +402,6 @@ scikit-build==0.18.1
 pyre-extensions==0.0.32
 tabulate==0.9.0
 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI
+
+Flask==3.1.2
+#Description: required for torch.distributed.debug
test/distributed/test_debug.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Owner(s): ["oncall: distributed"]

import os

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase


session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)


class TestDebug(TestCase):
    def test_basics(self) -> None:
        store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(store.port)
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"

        port = 25999

        def fetch(path: str) -> str:
            resp = session.get(f"http://localhost:{port}{path}")
            resp.raise_for_status()
            return resp.text

        print("starting!")

        start_debug_server(port=port)

        self.assertIn("torch profiler", fetch("/"))
        self.assertIn("View 0", fetch("/profile?duration=0.01"))
        self.assertIn("test_basics", fetch("/stacks"))
        self.assertIn("pg_status", fetch("/fr_trace"))
        self.assertIn("pg_status", fetch("/fr_trace_nccl"))

        stop_debug_server()


if __name__ == "__main__":
    run_tests()
torch/_C/_distributed_c10d.pyi
@@ -100,7 +100,9 @@ class Logger:
     def _set_static_graph(self) -> None: ...

 class _WorkerServer:
-    def __init__(self, socket_path: str) -> None: ...
+    port: int
+
+    def __init__(self, host_or_file: str, port: int = ...) -> None: ...
     def shutdown(self) -> None: ...

 def get_debug_level(): ...
@@ -206,6 +208,7 @@ class Store:
         desired_value: str,
     ) -> bytes: ...
     def delete_key(self, key: str) -> bool: ...
+    def multi_get(self, keys: list[str]) -> list[bytes]: ...
     def num_keys(self) -> int: ...
     def set_timeout(self, timeout: timedelta): ...
     @overload
@@ -871,3 +874,15 @@ class ProcessGroupXCCL(Backend):

 def _set_process_group(pg: ProcessGroup) -> None: ...
 def _current_process_group() -> ProcessGroup: ...
+
+class _Request:
+    def body(self) -> bytes: ...
+    def get_param(self, str) -> str: ...
+
+class _Response:
+    def set_content(self, content: str | bytes, content_type: str) -> None: ...
+    def set_status(self, status: int) -> None: ...
+
+def _register_handler(
+    name: str, handler: Callable[[_Request, _Response], None]
+) -> None: ...
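These stubs are the whole Python surface for custom endpoints: _register_handler attaches a callback to the per-rank worker server, which serves it at /handler/<name>. A minimal sketch of registering one (the echo_rank handler and its tag parameter are illustrative, not part of this commit):

    import os

    from torch._C._distributed_c10d import _register_handler, _Request, _Response


    def _echo_rank(req: _Request, resp: _Response) -> None:
        # get_param reads a query-string value, e.g. /handler/echo_rank?tag=x
        tag = req.get_param("tag")
        resp.set_content(f"rank={os.environ.get('RANK')} tag={tag}", "text/plain")
        resp.set_status(200)


    _register_handler("echo_rank", _echo_rank)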
torch/_C/_profiler.pyi
@@ -60,6 +60,7 @@ class _ExperimentalConfig:
         verbose: bool = ...,
         performance_events: list[str] = ...,
         enable_cuda_sync_events: bool = ...,
+        profile_all_threads: bool = ...,
     ) -> None: ...

 class ProfilerConfig:
torch/csrc/distributed/c10d/control_plane/Handlers.cpp
@@ -1,5 +1,7 @@
 #include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

+#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
+
 #include <fmt/format.h>
 #include <mutex>
 #include <shared_mutex>
@@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
   res.setStatus(200);
 }};

+RegisterHandler frTracehandler(
+    "fr_trace_json",
+    [](const Request&, Response& res) {
+      auto trace = ::c10d::dump_fr_trace_json(true, true);
+      res.setContent(std::move(trace), "application/json");
+      res.setStatus(200);
+    });
+
 } // namespace

 void registerHandler(const std::string& name, HandlerFunc f) {
torch/csrc/distributed/c10d/control_plane/Handlers.hpp
@@ -18,6 +18,14 @@ class TORCH_API Request {
   virtual const std::string& body() const = 0;

   virtual const std::multimap<std::string, std::string>& params() const = 0;
+
+  std::string getParam(const std::string& key) const {
+    auto it = params().find(key);
+    if (it != params().end()) {
+      return it->second;
+    }
+    return "";
+  }
 };

 // Response represents a response to the handler. This conceptually maps to an
torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp
@@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
     TORCH_CHECK(
         server_.bind_to_port(hostOrFile, 80),
         fmt::format("Error binding to {}", hostOrFile));
+  } else if (port == 0) {
+    C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
+    port_ = server_.bind_to_any_port(hostOrFile);
+    TORCH_CHECK(
+        port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
   } else {
     C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
     TORCH_CHECK(
         server_.bind_to_port(hostOrFile, port),
         fmt::format("Error binding to {}:{}", hostOrFile, port));
+    port_ = port;
   }

   serverThread_ = std::thread([this]() {
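The new else-if branch binds to an OS-assigned ephemeral port and records it in port_, which the port() accessor below exposes; that is what lets start_debug_server(worker_port=0) avoid port collisions between ranks on one host. A small sketch from the Python side (assuming a build that includes this change):

    # Binding to port 0 lets the OS pick a free port; the new `port`
    # property reports which one was actually bound, so it can be
    # advertised to other ranks via the TCPStore.
    from torch._C._distributed_c10d import _WorkerServer

    server = _WorkerServer("::", 0)
    print(server.port)  # e.g. 49231 (OS-assigned)
    server.shutdown()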
torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp
@@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {

   void shutdown();

+  int port() {
+    return port_;
+  }
+
  private:
   httplib::Server server_;
   std::thread serverThread_;
+  int port_;
 };

 } // namespace c10d::control_plane
torch/csrc/distributed/c10d/init.cpp
@@ -46,6 +46,7 @@

 #include <fmt/format.h>
 #include <pybind11/chrono.h>
+#include <pybind11/functional.h>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
 #include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@@ -4203,7 +4204,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
           }),
           py::arg("host_or_file"),
           py::arg("port") = -1)
-      .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
+      .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
+      .def_property_readonly(
+          "port", &::c10d::control_plane::WorkerServer::port);

   module.def(
       "_get_handler",
@@ -4219,6 +4222,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
       Returns the handler with the specified name.
       )");

+  module.def(
+      "_register_handler",
+      [](const std::string& name, const py::function& handler) {
+        ::c10d::control_plane::registerHandler(
+            name,
+            [handler](
+                const ::c10d::control_plane::Request& req,
+                ::c10d::control_plane::Response& res) {
+              py::gil_scoped_acquire acquire;
+              handler(std::ref(req), std::ref(res));
+            });
+      },
+      py::arg("name"),
+      py::arg("handler"),
+      R"(
+      Registers a handler by name.
+      )");
+
   module.def(
       "_get_handler_names",
       &::c10d::control_plane::getHandlerNames,
@@ -4236,12 +4258,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
       // Default constructor.
       .def(py::init<>())
       .def("body", &::c10d::control_plane::Request::body)
-      .def("params", &::c10d::control_plane::Request::params);
+      .def("get_param", &::c10d::control_plane::Request::getParam);

-  py::class_<
-      ::c10d::control_plane::Response,
-      std::shared_ptr<::c10d::control_plane::Response>,
-      PythonResponse>(
+  py::class_<::c10d::control_plane::Response, PythonResponse>(
       module,
       "_Response",
       R"(
torch/distributed/debug/__init__.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import multiprocessing
import socket

# import for registration side effect
import torch.distributed.debug._handlers  # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client


__all__ = [
    "start_debug_server",
    "stop_debug_server",
]

_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None


def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _WORKER_SERVER is None, "debug server already started"
    assert _DEBUG_SERVER_PROC is None, "debug server already started"

    store = tcpstore_client()

    _WORKER_SERVER = _WorkerServer("::", worker_port)

    RANK = get_rank()
    store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")

    from torch.distributed.debug._flask import main

    if RANK == 0:
        _DEBUG_SERVER_PROC = multiprocessing.Process(
            target=main, args=(port,), daemon=True
        )
        _DEBUG_SERVER_PROC.start()


def stop_debug_server() -> None:
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _DEBUG_SERVER_PROC is not None
    assert _WORKER_SERVER is not None

    _DEBUG_SERVER_PROC.terminate()
    _WORKER_SERVER.shutdown()
    _DEBUG_SERVER_PROC.join()

    _WORKER_SERVER = None
    _DEBUG_SERVER_PROC = None
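start_debug_server assumes the usual torchrun environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE): every rank starts a _WorkerServer and publishes its URL in the TCPStore, and rank 0 forks the Flask frontend. A hedged usage sketch (the script name and timings are illustrative):

    # debug_demo.py: launch with `torchrun --nproc-per-node=2 debug_demo.py`
    import time

    import torch.distributed as dist
    from torch.distributed.debug import start_debug_server, stop_debug_server

    dist.init_process_group("gloo")

    # Every rank starts a worker endpoint on an ephemeral port; rank 0 also
    # serves the aggregated UI at http://<rank0-host>:25999.
    start_debug_server(port=25999, worker_port=0)

    time.sleep(120)  # browse /stacks, /fr_trace, /profile while the job runs

    # NOTE: as written in this commit, stop_debug_server() asserts that the
    # frontend process exists, and that process is only created on rank 0,
    # so this sketch stops it there only.
    if dist.get_rank() == 0:
        stop_debug_server()

    dist.destroy_process_group()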
torch/distributed/debug/_flask.py (new file, 265 lines)
@@ -0,0 +1,265 @@
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator

import requests
from flask import Flask, render_template, request
from jinja2 import DictLoader

from torch.distributed.debug._store import get_world_size, tcpstore_client


def fetch_all(
    endpoint: str, args: str = ""
) -> tuple[list[str], Iterator[requests.Response]]:
    store = tcpstore_client()
    keys = [f"rank{r}" for r in range(get_world_size())]
    addrs = store.multi_get(keys)
    addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]

    with ThreadPoolExecutor(max_workers=10) as executor:
        resps = executor.map(requests.post, addrs)

    return addrs, resps


def format_json(blob: str):
    parsed = json.loads(blob)
    return json.dumps(parsed, indent=2)


templates = {
    "base.html": """
<!doctype html>
<head>
<title>{% block title %}{% endblock %} - PyTorch Distributed</title>
<link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">

<style>
body {
  margin: 0;
  font-family:
    -apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
    "Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
    "Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
  font-size: 1rem;
  font-weight: 400;
  line-height: 1.5;
  color: #212529;
  text-align: left;
  background-color: #fff;
}
h1, h2, h3, h4, h5, h6, .h1, .h2, .h3, .h4, .h5, .h6 {
  margin-bottom: .5rem;
  font-weight: 500;
  line-height: 1.2;
}
nav {
  background-color: rgba(0, 0, 0, 0.17);
  padding: 10px;
  display: flex;
  align-items: center;
  padding: 16px;
  justify-content: flex-start;
}
nav h1 {
  display: inline-block;
  margin: 0;
}
nav a {
  margin: 0 8px;
}
section {
  max-width: 1280px;
  padding: 16px;
  margin: 0 auto;
}
pre {
  white-space: pre-wrap;
  max-width: 100%;
}
</style>
</head>

<nav>
  <h1>Torch Distributed Debug Server</h1>

  <a href="/">Home</a>
  <a href="/stacks">Python Stack Traces</a>
  <a href="/fr_trace">FlightRecorder</a>
  <a href="/fr_trace_nccl">FlightRecorder NCCL</a>
  <a href="/profile">torch profiler</a>
</nav>

<section class="content">
  {% block header %}{% endblock %}
  {% for message in get_flashed_messages() %}
    <div class="flash">{{ message }}</div>
  {% endfor %}
  {% block content %}{% endblock %}
</section>
""",
    "index.html": """
{% extends "base.html" %}
{% block header %}
  <h1>{% block title %}Index{% endblock %}</h1>
{% endblock %}
{% block content %}
  Hi
{% endblock %}
""",
    "raw_resp.html": """
{% extends "base.html" %}
{% block header %}
  <h1>{% block title %}{{title}}{% endblock %}</h1>
{% endblock %}
{% block content %}
  {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
    <h2>Rank {{ i }}: {{ addr }}</h2>
    {% if resp.status_code != 200 %}
      <p>Failed to fetch: status={{ resp.status_code }}</p>
      <pre>{{ resp.text }}</pre>
    {% else %}
      <pre>{{ resp.text }}</pre>
    {% endif %}
  {% endfor %}
{% endblock %}
""",
    "json_resp.html": """
{% extends "base.html" %}
{% block header %}
  <h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
  {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
    <h2>Rank {{ i }}: {{ addr }}</h2>
    {% if resp.status_code != 200 %}
      <p>Failed to fetch: status={{ resp.status_code }}</p>
      <pre>{{ resp.text }}</pre>
    {% else %}
      <pre>{{ format_json(resp.text) }}</pre>
    {% endif %}
  {% endfor %}
{% endblock %}
""",
    "profile.html": """
{% extends "base.html" %}
{% block header %}
  <h1>{% block title %}torch.profiler{% endblock %}</h1>
{% endblock %}

{% block content %}
  <form action="/profile" method="get">
    <label for="duration">Duration (seconds):</label>
    <input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
    <input type="submit" value="Submit">
  </form>

  <script>
    function stringToArrayBuffer(str) {
      const encoder = new TextEncoder();
      return encoder.encode(str).buffer;
    }
    async function openPerfetto(data) {
      const ui = window.open('https://ui.perfetto.dev/#!/');
      if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }

      // Perfetto readiness handshake: PING until we receive PONG
      await new Promise((resolve, reject) => {
        const onMsg = (e) => {
          if (e.source === ui && e.data === 'PONG') {
            window.removeEventListener('message', onMsg);
            clearInterval(pinger);
            resolve();
          }
        };
        window.addEventListener('message', onMsg);
        const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
        setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
      }).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });

      ui.postMessage({
        perfetto: {
          buffer: stringToArrayBuffer(JSON.stringify(data)),
          title: "torch profiler",
          fileName: "trace.json",
        }
      }, '*');
    }
  </script>

  {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
    <h2>Rank {{ i }}: {{ addr }}</h2>
    {% if resp.status_code != 200 %}
      <p>Failed to fetch: status={{ resp.status_code }}</p>
      <pre>{{ resp.text }}</pre>
    {% else %}
      <script>
        function run{{ i }}() {
          var data = {{ resp.text | safe }};
          openPerfetto(data);
        }
      </script>

      <button onclick="run{{ i }}()">View {{ i }}</button>
    {% endif %}
  {% endfor %}
{% endblock %}
""",
}

app = Flask(__name__)
app.jinja_loader = DictLoader(templates)
app.jinja_env.globals.update(
    zip=zip,
    format_json=format_json,
    enumerate=enumerate,
)


@app.route("/")
def _index_handler():
    return render_template("index.html")


@app.route("/stacks")
def _stacks_handler():
    addrs, resps = fetch_all("dump_traceback")
    return render_template("raw_resp.html", title="Stacks", addrs=addrs, resps=resps)


@app.route("/fr_trace")
def _fr_trace_handler():
    addrs, resps = fetch_all("fr_trace_json")

    return render_template(
        "json_resp.html",
        title="FlightRecorder",
        addrs=addrs,
        resps=resps,
    )


@app.route("/fr_trace_nccl")
def _fr_trace_nccl_handler():
    addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")

    return render_template(
        "json_resp.html",
        title="FlightRecorder NCCL",
        addrs=addrs,
        resps=resps,
    )


@app.route("/profile")
def _profiler_handler():
    duration = request.args.get("duration", default=1.0, type=float)

    addrs, resps = fetch_all("torch_profile", f"duration={duration}")

    return render_template("profile.html", addrs=addrs, resps=resps)


def main(port: int) -> None:
    app.run(host="::", port=port)
torch/distributed/debug/_handlers.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import tempfile
import time

from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile


def _torch_profile(req: _Request, resp: _Response) -> None:
    experimental_config = _ExperimentalConfig(
        profile_all_threads=True,
    )
    duration = float(req.get_param("duration"))
    with profile(record_shapes=True, experimental_config=experimental_config) as prof:
        time.sleep(duration)

    with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
        prof.export_chrome_trace(f.name)
        resp.set_content(open(f.name, "rb").read(), "application/json")
        resp.set_status(200)


_register_handler("torch_profile", _torch_profile)
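Any rank's worker endpoint can also be queried directly, without going through the Flask frontend; the handler reads duration via _Request.get_param. A sketch (http://localhost:49231 stands in for whatever URL that rank published to the TCPStore):

    import requests

    # The frontend posts to /handler/<name>, so we do the same here.
    resp = requests.post("http://localhost:49231/handler/torch_profile?duration=2")
    resp.raise_for_status()

    # The body is a Chrome trace; open it in Perfetto or chrome://tracing.
    with open("trace.json", "wb") as f:
        f.write(resp.content)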
torch/distributed/debug/_store.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import os

import torch.distributed as dist


def get_rank() -> int:
    return int(os.environ["RANK"])


def get_world_size() -> int:
    return int(os.environ["WORLD_SIZE"])


def tcpstore_client() -> dist.Store:
    MASTER_ADDR = os.environ["MASTER_ADDR"]
    MASTER_PORT = int(os.environ["MASTER_PORT"])

    store = dist.TCPStore(
        host_name=MASTER_ADDR,
        port=MASTER_PORT,
        is_master=False,
    )
    store = dist.PrefixStore("debug_server", store)
    return store
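tcpstore_client attaches to the job's existing TCPStore as a client, under a "debug_server" prefix; this is how worker URLs flow from start_debug_server on each rank to fetch_all on rank 0. A round-trip sketch under the same environment assumptions (the URL is a placeholder):

    from torch.distributed.debug._store import get_rank, get_world_size, tcpstore_client

    store = tcpstore_client()

    # What start_debug_server() does on every rank:
    store.set(f"rank{get_rank()}", "http://worker-host:49231")

    # What the rank-0 frontend's fetch_all() does to find every worker:
    addrs = store.multi_get([f"rank{r}" for r in range(get_world_size())])
    print([a.decode() for a in addrs])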