[torchelastic] add timing events to different stages of rendezvous (#125636)

Summary: as title

Test Plan: unit tests. Launched a test job and observed scuba results: {F1506543300}

Differential Revision: D57018103

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125636
Approved by: https://github.com/d4l3k
This commit is contained in:
Michael Suo
2024-05-08 01:14:23 +00:00
committed by PyTorch MergeBot
parent a3d97f6ce4
commit 21aaac47e7

View File

@ -14,7 +14,7 @@ import socket
import time
import traceback
import warnings
from contextlib import closing
from contextlib import closing, contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@ -545,10 +545,12 @@ class SimpleElasticAgent(ElasticAgent):
"""
spec = worker_group.spec
store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
with self.record_duration("RENDEZVOUS"):
store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
self._store = store
workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
with self.record_duration("ASSIGN_WORKER_RANKS"):
workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
worker_group.workers = workers
worker_group.store = store
worker_group.group_rank = group_rank
@ -562,7 +564,8 @@ class SimpleElasticAgent(ElasticAgent):
spec.local_addr,
)
master_addr, master_port = self._get_master_addr_port(store)
with self.record_duration("GET_MASTER_ADDR_PORT"):
master_addr, master_port = self._get_master_addr_port(store)
restart_count = spec.max_restarts - self._remaining_restarts
logger.info(
@ -781,12 +784,23 @@ class SimpleElasticAgent(ElasticAgent):
else:
raise ValueError(f"Unknown worker: {worker.global_rank}")
@contextmanager
def record_duration(self, state: str):
start_time = time.perf_counter()
try:
yield
finally:
end_time = time.perf_counter()
duration_ms = (end_time - start_time) * 1000
record(self._construct_event(state=state, source=EventSource.AGENT, duration_ms=duration_ms))
def _construct_event(
self,
state: str,
source: EventSource,
worker: Optional[Worker] = None,
raw_error: Optional[str] = None,
duration_ms: Optional[float] = None,
) -> Event:
wg = self._worker_group
spec = wg.spec
@ -817,6 +831,7 @@ class SimpleElasticAgent(ElasticAgent):
"raw_error": raw_error,
"metadata": md_str,
"agent_restarts": spec.max_restarts - self._remaining_restarts,
"duration_ms": duration_ms,
}
return Event(
f"torchelastic.worker.status.{state}", source=source, metadata=metadata