Files
pytorch/test/distributed/launcher/api_test.py
Jane Xu 251278d385 [skip ci] set more tests with owners for distributed and elastic (#67583)
Summary:
It turns out my lint doesn't work on CI all the time because of shell differences. I'm working on a new more comprehensive lint in https://github.com/pytorch/pytorch/pull/66826 and it'd be nice if these could be cleared first.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67583

Reviewed By: H-Huang, mruberry

Differential Revision: D32045155

Pulled By: janeyx99

fbshipit-source-id: ecfe9f008310c28e3b731e246c2b2ed0106d03b1
2021-11-01 12:26:03 -07:00

341 lines
11 KiB
Python

#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import multiprocessing as mp
import os
import shutil
import sys
import tempfile
import time
import unittest
import uuid
from contextlib import closing
from typing import Optional, Any, Dict
from unittest import mock
from unittest.mock import Mock, patch
import torch
import torch.distributed as dist
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.launcher.api import (
LaunchConfig,
elastic_launch,
_get_entrypoint_name,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
TEST_WITH_TSAN,
sandcastle_skip_if,
)
def path(script):
    """Resolve ``script`` relative to the directory that contains this test file."""
    here = os.path.dirname(__file__)
    return os.path.join(here, script)
def simple_rank_scale():
    """Return ``10 + RANK``, with RANK read from the process environment.

    Used as a function entrypoint for ``elastic_launch``. Raises ``KeyError``
    if RANK is unset and ``ValueError`` if it is not an integer string.
    """
    offset = 10
    return offset + int(os.environ["RANK"])
def function_with_bug():
    """Unconditionally raise ``RuntimeError("test error")``."""
    message = "test error"
    raise RuntimeError(message)
def _dist_sum(wait=0):
    """All-reduce (sum) this worker's RANK over a gloo process group.

    Reads RANK from the environment and calls ``init_process_group`` with no
    explicit init method, so it presumably relies on the launcher having set
    the rendezvous environment variables (TODO confirm against the launcher).
    Sleeps ``wait`` seconds before the collective, then returns the summed
    value as a Python number.
    """
    rank = int(os.environ["RANK"])
    dist.init_process_group(backend="gloo")
    t = torch.tensor(rank)
    time.sleep(wait)
    # ``dist.reduce_op`` is a deprecated alias that warns on access;
    # ``dist.ReduceOp`` is the supported enum.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t.item()
class MockException(Exception):
    """Marker exception that tests inject via mocks and expect to see re-raised."""
class ElasticLaunchTest(unittest.TestCase):
    """Integration tests for ``elastic_launch`` and related launcher helpers.

    All tests share one standalone etcd server started in ``setUpClass``.
    Each test gets a fresh temporary directory (``self.test_dir``); worker
    scripts touch one file per global rank into it, which
    ``check_works_ran`` then verifies.
    """

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests.
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server.
        cls._etcd_server.stop()

    def setUp(self):
        """Create the per-test directory and sanitize the parent environment."""
        self.test_dir = tempfile.mkdtemp()

        # remove any lingering environment variables.
        # Snapshot the keys first so deletion never races with iteration.
        for env in list(os.environ):
            if env.startswith("PET_"):
                del os.environ[env]

        # set a sentinel env var on the parent proc.
        # this should be present on the child and gets
        # asserted in ``bin/test_script.py``.
        os.environ["TEST_SENTINEL_PARENT"] = "FOOBAR"
        os.environ["OMP_NUM_THREADS"] = str(1)

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def get_test_launch_config(
        self,
        min_nodes: int,
        max_nodes: int,
        nproc_per_node: int,
        run_id: str = "",
        rdzv_backend: str = "etcd",
        config: Optional[Dict[str, Any]] = None,
        rdzv_endpoint: Optional[str] = None,
    ) -> LaunchConfig:
        """Build a ``LaunchConfig`` suitable for these tests.

        Args:
            min_nodes, max_nodes, nproc_per_node, run_id, rdzv_backend:
                forwarded unchanged to ``LaunchConfig``.
            config: extra rendezvous configs. They are copied, so the caller's
                dict is never mutated.
            rdzv_endpoint: overrides the shared etcd endpoint when truthy.

        Returns:
            A ``LaunchConfig`` using the ``spawn`` start method, a 1s monitor
            interval and no restarts.
        """
        rdzv_configs = dict(config) if config else {}
        endpoint = rdzv_endpoint or self._etcd_endpoint
        return LaunchConfig(
            min_nodes=min_nodes,
            max_nodes=max_nodes,
            nproc_per_node=nproc_per_node,
            run_id=run_id,
            rdzv_endpoint=endpoint,
            monitor_interval=1,
            rdzv_backend=rdzv_backend,
            start_method="spawn",
            max_restarts=0,
            rdzv_configs=rdzv_configs,
        )

    def check_works_ran(self, world_size: int):
        """Assert ``self.test_dir`` holds exactly one file per global rank."""
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @sandcastle_skip_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_script_python(self):
        nnodes = 1
        nproc_per_node = 4

        elastic_launch(
            self.get_test_launch_config(nnodes, nnodes, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

        # make sure all the workers ran.
        # each worker touches a file with its global rank as the name.
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @sandcastle_skip_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_script_python_local_rank_transfer(self):
        # NOTE(review): this body is identical to test_launch_script_python and
        # does not exercise anything local-rank specific — confirm intent.
        nnodes = 1
        nproc_per_node = 4

        elastic_launch(
            self.get_test_launch_config(nnodes, nnodes, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

        # make sure all the workers ran.
        # each worker touches a file with its global rank as the name.
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @sandcastle_skip_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_script_bash(self):
        nnodes = 1
        nproc_per_node = 4

        elastic_launch(
            self.get_test_launch_config(nnodes, nnodes, nproc_per_node),
            path("bin/test_script.sh"),
        )(f"{self.test_dir}")

        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @sandcastle_skip_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_function(self):
        nnodes = 1
        nproc_per_node = 4

        res = elastic_launch(
            self.get_test_launch_config(nnodes, nnodes, nproc_per_node),
            simple_rank_scale,
        )()

        expected_res = [10, 11, 12, 13]
        actual_res = sorted(res.values())
        self.assertEqual(expected_res, actual_res)

    @sandcastle_skip_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_dist_sum_with_static_rdzv(self):
        nnodes = 1
        nproc_per_node = 4
        # Grab a free port, then release it so the static rendezvous can bind it.
        sock = get_socket_with_port()
        with closing(sock):
            master_port = sock.getsockname()[1]
        rdzv_endpoint = f"127.0.0.1:{master_port}"
        rdzv_config = {
            "rank": 0,
        }

        res = elastic_launch(
            self.get_test_launch_config(
                nnodes,
                nnodes,
                nproc_per_node,
                rdzv_backend="static",
                config=rdzv_config,
                rdzv_endpoint=rdzv_endpoint,
            ),
            _dist_sum,
        )()

        expected_res = [sum(range(nproc_per_node))] * nproc_per_node
        actual_res = sorted(res.values())
        self.assertEqual(expected_res, actual_res)

    @sandcastle_skip_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_elastic(self):
        nproc_per_node = 4

        elastic_launch(
            self.get_test_launch_config(1, 2, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

        world_size = nproc_per_node
        self.check_works_ran(world_size)

    @mock.patch("torch.distributed.elastic.events.record")
    def test_launch_elastic_worker_raise_exception(self, record_mock):
        """
        Asserts that when the worker program fails, the launcher raises an
        exception indicating that a worker process failed.
        """
        nproc_per_node = 4

        with self.assertRaises(ChildFailedError):
            elastic_launch(
                self.get_test_launch_config(1, 2, nproc_per_node),
                sys.executable,
            )("-u", path("bin/test_script.py"), "--fail")

        record_mock.assert_called_once()

    @mock.patch("torch.distributed.elastic.events.record")
    @mock.patch(
        "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run"
    )
    def test_launch_elastic_agent_raise_exception(self, mock_agent_run, record_mock):
        """
        Asserts that when the agent raises an exception
        the launcher re-raises the original exception.
        """
        # Stacked patch decorators inject mocks bottom-up: the innermost
        # (LocalElasticAgent.run) mock arrives first, events.record second.
        mock_agent_run.side_effect = MockException
        with self.assertRaises(MockException):
            elastic_launch(
                self.get_test_launch_config(1, 2, 4),
                sys.executable,
            )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")
        record_mock.assert_called_once()

    @unittest.skipIf(TEST_WITH_TSAN, "test incompatible with tsan")
    def test_launch_elastic_multiple_agents(self):
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        def elastic_launch_wrapper():
            """We need a wrapper function for class `elastic_launch.` in order to make multiprocess returns correct exit code."""
            elastic_launch(
                self.get_test_launch_config(
                    min_nodes, max_nodes, nproc_per_node, run_id
                ),
                sys.executable,
            )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

        # Run nnodes - 1 agents in child processes and one in this process.
        procs = []
        for _ in range(nnodes - 1):
            p = mp.Process(
                target=elastic_launch_wrapper,
            )
            procs.append(p)
            p.start()

        elastic_launch_wrapper()

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @patch("torch.distributed.launcher.api.LocalElasticAgent")
    def test_launch_shutdown(self, agent_mock_cls):
        """Asserts the rendezvous handler is shut down exactly once after a
        (mocked) successful agent run."""
        agent_mock = Mock()
        agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
        agent_mock_cls.return_value = agent_mock
        rdzv_handler_mock = Mock()
        with patch(
            "torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
        ) as param_mock:
            param_mock.return_value = rdzv_handler_mock
            elastic_launch(
                self.get_test_launch_config(1, 1, 4),
                sys.executable,
            )("-u", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}")

            rdzv_handler_mock.shutdown.assert_called_once()

    def test_get_entrypoint_name(self):
        """Checks ``_get_entrypoint_name`` for callable, executable and None entrypoints."""
        self.assertEqual(
            "simple_rank_scale", _get_entrypoint_name(simple_rank_scale, [])
        )
        self.assertEqual("", _get_entrypoint_name(sys.executable, []))
        self.assertEqual("", _get_entrypoint_name(sys.executable, ["-u"]))
        self.assertEqual(
            "test_script.py",
            _get_entrypoint_name(sys.executable, ["-u", "test_script.py"]),
        )
        self.assertEqual("", _get_entrypoint_name(None, []))