Summary: The original decorator-based time logger performs poorly and has limited capacity, so it is being replaced with PyTorch's `_WaitCounter`.
Test Plan: Tested on a workload; no regression was observed: https://fburl.com/scuba/aps_instrumentation_components/mskj73ea
Differential Revision: D66218675
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141154
Approved by: https://github.com/wz337
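For context on the pattern referenced above: instead of timing a function through a wrapping decorator, the region of interest is placed under a wait-counter guard. A minimal sketch, assuming the `torch._C._WaitCounter(...).guard()` context-manager API; the counter key and function name below are hypothetical and not taken from this PR:

from torch._C import _WaitCounter


def do_collective_work() -> None:
    # The work previously wrapped by the decorator-based time logger goes here;
    # the counter key is illustrative only.
    with _WaitCounter("pytorch.wait_counter.example_region").guard():
        ...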
154 lines · 4.7 KiB · Python
# Owner(s): ["oncall: distributed"]

import json
import logging
import os
import re
import sys
from functools import partial, wraps

import torch
import torch.distributed as dist
from torch.distributed.c10d_logger import _c10d_logger, _exception_logger


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

BACKEND = dist.Backend.NCCL
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))


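# Test decorator: initialize the default process group before the test body and
# tear it down afterwards; skip when NCCL is selected but there are fewer GPUs
# than ranks.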
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()

    return wrapper


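# Spawns `world_size` processes via MultiProcessTestCase and validates both the
# c10d logger handler setup and the metadata emitted by _exception_logger.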
class C10dErrorLoggerTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def test_get_or_create_logger(self):
        self.assertIsNotNone(_c10d_logger)
        self.assertEqual(1, len(_c10d_logger.handlers))
        self.assertIsInstance(_c10d_logger.handlers[0], logging.NullHandler)

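    # Helpers that broadcast from a nonexistent source rank (world_size + 1) so
    # that _exception_logger has a failure to record; the second variant swallows
    # the exception so only the log record is observable.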
    @_exception_logger
    def _failed_broadcast_raise_exception(self):
        tensor = torch.arange(2, dtype=torch.int64)
        dist.broadcast(tensor, self.world_size + 1)

    @_exception_logger
    def _failed_broadcast_not_raise_exception(self):
        try:
            tensor = torch.arange(2, dtype=torch.int64)
            dist.broadcast(tensor, self.world_size + 1)
        except Exception:
            pass

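    # Parses the JSON payload emitted by _exception_logger from the captured log
    # record and verifies the expected metadata fields.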
    @with_comms
    def test_exception_logger(self) -> None:
        with self.assertRaises(Exception):
            self._failed_broadcast_raise_exception()

        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
            self._failed_broadcast_not_raise_exception()
            error_msg_dict = json.loads(
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
            )

        self.assertEqual(len(error_msg_dict), 9)

        self.assertIn("pg_name", error_msg_dict.keys())
        self.assertEqual("None", error_msg_dict["pg_name"])

        self.assertIn("func_name", error_msg_dict.keys())
        self.assertEqual("broadcast", error_msg_dict["func_name"])

        self.assertIn("backend", error_msg_dict.keys())
        self.assertEqual("nccl", error_msg_dict["backend"])

        self.assertIn("nccl_version", error_msg_dict.keys())
        nccl_ver = torch.cuda.nccl.version()
        self.assertEqual(
            ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
        )

        # In this test case, group_size == world_size, since the default process group spans all ranks.
        self.assertIn("group_size", error_msg_dict.keys())
        self.assertEqual(str(self.world_size), error_msg_dict["group_size"])

        self.assertIn("world_size", error_msg_dict.keys())
        self.assertEqual(str(self.world_size), error_msg_dict["world_size"])

        self.assertIn("global_rank", error_msg_dict.keys())
        self.assertIn(str(dist.get_rank()), error_msg_dict["global_rank"])

        # In this test case, local_rank == global_rank, since all ranks run on a single node.
        self.assertIn("local_rank", error_msg_dict.keys())
        self.assertIn(str(dist.get_rank()), error_msg_dict["local_rank"])


if __name__ == "__main__":
    run_tests()