Files
pytorch/torch/distributed/elastic/control_plane.py
Tristan Rice 597922ba21 Reapply "distributed debug handlers (#126601)" (#127805)
This reverts commit 7646825c3eb687030c4f873b01312be0eed80174.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127805
Approved by: https://github.com/PaliC
2024-06-04 19:44:30 +00:00

52 lines
1.1 KiB
Python

import os
from contextlib import contextmanager, ExitStack
from typing import Generator
from torch.distributed.elastic.multiprocessing.errors import record
__all__ = ["worker_main"]

# Environment variable consulted by ``worker_main``: when it is set, its value
# is used as the unix socket path on which a ``_WorkerServer`` is started.
TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
    """Keep a ``_WorkerServer`` bound to ``socket_path`` alive for the context.

    The server is created on entry and its ``shutdown()`` is invoked on exit,
    whether the managed block finishes normally or raises.
    """
    # Resolved at call time rather than at module import; presumably so the
    # C extension is only touched when a server is actually requested (confirm).
    from torch._C._distributed_c10d import _WorkerServer

    with ExitStack() as cleanup:
        worker_server = _WorkerServer(socket_path)
        # Registered callback runs on every exit path, mirroring try/finally.
        cleanup.callback(worker_server.shutdown)
        yield
def _reraise(exc: BaseException) -> None:
    """Raise ``exc``; exists only to give ``record`` a plain function to wrap."""
    raise exc


@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
    """
    Context manager that wraps your main entry function.

    This combines the existing ``errors.record`` logic with a ``_WorkerServer``
    that exposes handlers via a unix socket whose path is taken from the
    ``TORCH_WORKER_SERVER_SOCKET`` environment variable. If that variable is
    not set, no server is started.

    Exceptions raised inside the managed block are passed through ``record``
    (so they are recorded as for a ``@record``-decorated entrypoint) and are
    then re-raised unchanged.

    Example

    ::

     @worker_main()
     def main():
         pass

     if __name__=="__main__":
        main()
    """
    # NOTE: the ``@record`` decorator above only wraps creation of this
    # generator, so it cannot see exceptions raised in the managed block
    # (those are thrown into the generator at ``yield``). They are therefore
    # routed through ``record`` explicitly below.
    with ExitStack() as stack:
        socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
        if socket_path is not None:
            stack.enter_context(_worker_server(socket_path))

        try:
            yield
        except Exception as exc:
            # ``record``'s wrapper records the exception and re-raises it.
            # Only ``Exception`` is intercepted, so SystemExit/KeyboardInterrupt
            # and GeneratorExit propagate exactly as before.
            record(_reraise)(exc)
            raise  # defensive: keep propagating even if record() ever returns