logging start of torch elastic workers. (#150849)

Summary:
We would like to log start of the workers. It will help with complete logging.

Test Plan:
unit tests

https://www.internalfb.com/intern/testinfra/testrun/6473924724652056

e2e tests
https://www.internalfb.com/mlhub/pipelines/runs/mast/f712311762-27449483648-TrainingApplication_V403K?job_attempt=0&version=0&tab=execution_details&env=PRODUCTION

Reviewed By: tnykiel

Differential Revision: D72297314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150849
Approved by: https://github.com/d4l3k, https://github.com/kiukchung
This commit is contained in:
Amandeep Chhabra
2025-04-22 22:35:04 +00:00
committed by PyTorch MergeBot
parent 6a1b820255
commit a7ccd96bbf
2 changed files with 7 additions and 1 deletions

View File

@ -350,7 +350,8 @@ class SimpleElasticAgentTest(unittest.TestCase):
self.assertEqual(spec_local_addr, worker_group.master_addr)
self.assertGreater(worker_group.master_port, 0)
def test_initialize_workers(self):
@patch.object(TestAgent, "_construct_event")
def test_initialize_workers(self, mock_construct_event):
spec = self._get_worker_spec(max_restarts=1)
agent = TestAgent(spec)
worker_group = agent.get_worker_group()
@ -361,6 +362,9 @@ class SimpleElasticAgentTest(unittest.TestCase):
worker = worker_group.workers[i]
self.assertEqual(worker.id, worker.global_rank)
mock_construct_event.assert_called()
self.assertEqual(mock_construct_event.call_count, 10)
def test_restart_workers(self):
spec = self._get_worker_spec()
agent = TestAgent(spec)

View File

@ -687,6 +687,7 @@ class SimpleElasticAgent(ElasticAgent):
for local_rank, w_id in worker_ids.items():
worker = worker_group.workers[local_rank]
worker.id = w_id
record(self._construct_event("START", EventSource.WORKER, worker))
worker_group.state = WorkerState.HEALTHY
@ -809,6 +810,7 @@ class SimpleElasticAgent(ElasticAgent):
"agent_restarts": spec.max_restarts - self._remaining_restarts,
"duration_ms": duration_ms,
}
return Event(
f"torchelastic.worker.status.{state}", source=source, metadata=metadata
)