mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[torchelastic] add timing events to different stages of rendezvous (#125636)
Summary: as title
Test Plan: unit tests. Launched a test job and observed scuba results: {F1506543300}
Differential Revision: D57018103
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125636
Approved by: https://github.com/d4l3k
This commit is contained in:
committed by
PyTorch MergeBot
parent
a3d97f6ce4
commit
21aaac47e7
@ -14,7 +14,7 @@ import socket
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from contextlib import closing
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
@ -545,10 +545,12 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
"""
|
||||
spec = worker_group.spec
|
||||
|
||||
store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
|
||||
with self.record_duration("RENDEZVOUS"):
|
||||
store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
|
||||
self._store = store
|
||||
|
||||
workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
|
||||
with self.record_duration("ASSIGN_WORKER_RANKS"):
|
||||
workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
|
||||
worker_group.workers = workers
|
||||
worker_group.store = store
|
||||
worker_group.group_rank = group_rank
|
||||
@ -562,7 +564,8 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
spec.local_addr,
|
||||
)
|
||||
|
||||
master_addr, master_port = self._get_master_addr_port(store)
|
||||
with self.record_duration("GET_MASTER_ADDR_PORT"):
|
||||
master_addr, master_port = self._get_master_addr_port(store)
|
||||
restart_count = spec.max_restarts - self._remaining_restarts
|
||||
|
||||
logger.info(
|
||||
@ -781,12 +784,23 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
else:
|
||||
raise ValueError(f"Unknown worker: {worker.global_rank}")
|
||||
|
||||
@contextmanager
def record_duration(self, state: str):
    """Time the enclosed ``with`` body and record an agent event with the result.

    The elapsed wall time is measured with ``time.perf_counter()`` and
    converted to milliseconds. The event is built through
    ``self._construct_event`` with ``source=EventSource.AGENT`` and handed to
    ``record``.

    Args:
        state: State label forwarded to ``self._construct_event``.

    Yields:
        Nothing; the context manager exists only for its timing side effect.
    """
    began = time.perf_counter()
    try:
        yield
    finally:
        # Runs in ``finally`` so the duration is recorded even when the
        # wrapped body raises.
        elapsed_ms = (time.perf_counter() - began) * 1000
        event = self._construct_event(
            state=state,
            source=EventSource.AGENT,
            duration_ms=elapsed_ms,
        )
        record(event)
|
||||
|
||||
def _construct_event(
|
||||
self,
|
||||
state: str,
|
||||
source: EventSource,
|
||||
worker: Optional[Worker] = None,
|
||||
raw_error: Optional[str] = None,
|
||||
duration_ms: Optional[float] = None,
|
||||
) -> Event:
|
||||
wg = self._worker_group
|
||||
spec = wg.spec
|
||||
@ -817,6 +831,7 @@ class SimpleElasticAgent(ElasticAgent):
|
||||
"raw_error": raw_error,
|
||||
"metadata": md_str,
|
||||
"agent_restarts": spec.max_restarts - self._remaining_restarts,
|
||||
"duration_ms": duration_ms,
|
||||
}
|
||||
return Event(
|
||||
f"torchelastic.worker.status.{state}", source=source, metadata=metadata
|
||||
|
Reference in New Issue
Block a user