mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Prevent rendezvous shutdown on worker restarts (#124819)
Fixes #123678

#### Summary
When a rank leaves and then rejoins, the workers are restarted, and during that restart the rendezvous was also being shut down. This change prevents the rendezvous from being shut down during worker restarts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124819
Approved by: https://github.com/malfet, https://github.com/kurman, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
6c4f43f826
commit
0bde9c08ef
@ -85,6 +85,16 @@ def dummy_compute() -> torch.Tensor:
|
||||
return torch.rand(100, 100)
|
||||
|
||||
|
||||
def dummy_compute_simulate_rank_failure() -> torch.Tensor:
    """
    Simulate a one-time failure of rank 1.

    When the ``RANK`` environment variable is ``"1"`` and
    ``TORCHELASTIC_RESTART_COUNT`` is ``"0"``, the calling process sends
    signal 9 to itself. In every other case a random tensor of the
    predefined size (100 x 100) is returned.
    """
    env = os.environ
    # The restart counter is only consulted for rank 1, so it does not need
    # to be looked up for any other rank.
    should_die = env["RANK"] == "1" and env["TORCHELASTIC_RESTART_COUNT"] == "0"
    if should_die:
        os.kill(os.getpid(), 9)
    return torch.rand(100, 100)
|
||||
|
||||
|
||||
def _fatal_signal_function(expected_error_index: int, sig: int):
|
||||
rank = int(os.environ["RANK"])
|
||||
if rank == expected_error_index:
|
||||
@ -1440,3 +1450,19 @@ class LocalElasticAgentTest(unittest.TestCase):
|
||||
)
|
||||
def test_shutdown_called_etcd_v2(self):
|
||||
self.run_test_with_backend(backend="etcd-v2", test_to_run=self.shutdown_called)
|
||||
|
||||
def fail_rank_one_once(self):
    """
    Run the agent with an entrypoint that kills rank 1 on its first attempt.

    Launches two local workers with up to three restarts allowed, then
    checks that the run is not reported as failed and that every returned
    value is a ``torch.Tensor`` of shape ``(100, 100)``.
    """
    conf = Conf(entrypoint=dummy_compute_simulate_rank_failure, local_world_size=2)
    result = self.run_agent(conf, max_restarts=3)
    self.assertFalse(result.is_failed())

    expected_shape = (100, 100)
    for tensor in result.return_values.values():
        self.assertIsInstance(tensor, torch.Tensor)
        self.assertEqual(expected_shape, tensor.shape)
|
||||
|
||||
@skip_but_pass_in_sandcastle_if(
    TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
def test_rank_restart_after_failure(self):
    """Run ``fail_rank_one_once`` via ``run_test_with_backend`` on the c10d backend."""
    scenario = self.fail_rank_one_once
    self.run_test_with_backend(backend="c10d", test_to_run=scenario)
|
||||
|
@ -480,7 +480,7 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _stop_workers(self, worker_group: WorkerGroup) -> None:
|
||||
def _stop_workers(self, worker_group: WorkerGroup, is_restart: bool = False) -> None:
|
||||
r"""Stop all workers in the given worker group.
|
||||
|
||||
Implementors must deal with workers in all states defined by
|
||||
@ -498,7 +498,7 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
|
||||
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False) -> None:
|
||||
"""Clean up any resources that were allocated during the agent's work.
|
||||
|
||||
Args:
|
||||
@ -725,7 +725,7 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
"""Restart (stops, rendezvous, starts) all local workers in the group."""
|
||||
role = worker_group.spec.role
|
||||
logger.info("[%s] Stopping worker group", role)
|
||||
self._stop_workers(worker_group)
|
||||
self._stop_workers(worker_group, is_restart=True)
|
||||
worker_group.state = WorkerState.STOPPED
|
||||
self._initialize_workers(worker_group)
|
||||
|
||||
|
@ -263,8 +263,8 @@ class LocalElasticAgent(SimpleElasticAgent):
|
||||
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
|
||||
# `torch.distributed.elastic.metrics.prof`.
|
||||
@prof
|
||||
def _stop_workers(self, worker_group: WorkerGroup) -> None:
|
||||
self._shutdown()
|
||||
def _stop_workers(self, worker_group: WorkerGroup, is_restart: bool = False) -> None:
|
||||
self._shutdown(is_restart=is_restart)
|
||||
|
||||
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
|
||||
# `torch.distributed.elastic.metrics.prof`.
|
||||
@ -336,7 +336,7 @@ class LocalElasticAgent(SimpleElasticAgent):
|
||||
|
||||
return self._pcontext.pids()
|
||||
|
||||
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
|
||||
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False) -> None:
|
||||
if self._worker_watchdog is not None:
|
||||
self._worker_watchdog.stop()
|
||||
self._worker_watchdog = None
|
||||
@ -345,7 +345,7 @@ class LocalElasticAgent(SimpleElasticAgent):
|
||||
self._health_check_server = None
|
||||
if self._pcontext:
|
||||
self._pcontext.close(death_sig)
|
||||
if self._rdzv_handler:
|
||||
if not is_restart and self._rdzv_handler:
|
||||
self._rdzv_handler.shutdown()
|
||||
|
||||
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
|
||||
|
Reference in New Issue
Block a user