Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Summary: Some machines don't have a versionless `python` on their PATH, which breaks these existing shebangs. I'm assuming that all the existing versionless `python` shebangs are meant to be `python3` and not `python2`; please let me know if my assumption was incorrect for any of these.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58275

Test Plan: CI.

Reviewed By: zhouzhuojie
Differential Revision: D28428143
Pulled By: samestep
fbshipit-source-id: 6562be3d12924db72a92a0207b060ef740f61ebf
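For example, a shebang of `#!/usr/bin/env python` becomes `#!/usr/bin/env python3`, as seen at the top of the file below.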
707 lines
25 KiB
Python
#!/usr/bin/env python3

from hypothesis import given, settings
import hypothesis.strategies as st
from multiprocessing import Process, Queue

import numpy as np
import os
import pickle
import tempfile
import shutil

from caffe2.python import core, workspace, dyndep
import caffe2.python.hypothesis_test_util as hu
from gloo.python import IoError

dyndep.InitOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
dyndep.InitOpsLibrary("@/caffe2/caffe2/distributed:redis_store_handler_ops")
dyndep.InitOpsLibrary("@/caffe2/caffe2/distributed:store_ops")
dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/gloo:gloo_ops")
dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/gloo:gloo_ops_gpu")

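# Every collective op in this file is run with the Gloo engine implementation.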
op_engine = 'GLOO'


class TemporaryDirectory:
    def __enter__(self):
        self.tmpdir = tempfile.mkdtemp()
        return self.tmpdir

    def __exit__(self, type, value, traceback):
        shutil.rmtree(self.tmpdir)


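# Each collective is tested in two modes: locally, by forking comm_size worker
# processes, and distributed, where COMM_RANK/COMM_SIZE come from the
# environment. test_counter namespaces the Redis key prefix per test;
# sync_counter names the blobs used for store-based synchronization.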
class TestCase(hu.HypothesisTestCase):
    test_counter = 0
    sync_counter = 0

    def run_test_locally(self, fn, device_option=None, **kwargs):
        # Queue for assertion errors on subprocesses
        queue = Queue()

        # Capture any exception thrown by the subprocess
        def run_fn(*args, **kwargs):
            try:
                with core.DeviceScope(device_option):
                    fn(*args, **kwargs)
                    workspace.ResetWorkspace()
                queue.put(True)
            except Exception as ex:
                queue.put(ex)

        # Start N processes in the background
        procs = []
        for i in range(kwargs['comm_size']):
            kwargs['comm_rank'] = i
            proc = Process(
                target=run_fn,
                kwargs=kwargs)
            proc.start()
            procs.append(proc)

        # Test complete, join background processes
        while len(procs) > 0:
            proc = procs.pop(0)
            while proc.is_alive():
                proc.join(10)

            # Raise exception if we find any. Otherwise each worker
            # should put a True into the queue
            # Note that the following is executed ALSO after
            # the last process was joined, so if ANY exception
            # was raised, it will be re-raised here.
            self.assertFalse(queue.empty(), "Job failed without a result")
            o = queue.get()
            if isinstance(o, Exception):
                raise o
            else:
                self.assertTrue(o)

    def run_test_distributed(self, fn, device_option=None, **kwargs):
        comm_rank = os.getenv('COMM_RANK')
        self.assertIsNotNone(comm_rank)
        comm_size = os.getenv('COMM_SIZE')
        self.assertIsNotNone(comm_size)
        kwargs['comm_rank'] = int(comm_rank)
        kwargs['comm_size'] = int(comm_size)
        with core.DeviceScope(device_option):
            fn(**kwargs)
            workspace.ResetWorkspace()

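    # Rendezvous through a RedisStoreHandler when REDIS_HOST is set, otherwise
    # through a FileStoreHandler rooted at tmpdir, then create a Gloo common
    # world of comm_size participants (or clone an existing one).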
    def create_common_world(self, comm_rank, comm_size, tmpdir=None, existing_cw=None):
        store_handler = "store_handler"

        # If REDIS_HOST is set, use RedisStoreHandler for rendezvous.
        if existing_cw is None:
            redis_host = os.getenv("REDIS_HOST")
            redis_port = int(os.getenv("REDIS_PORT", 6379))
            if redis_host is not None:
                workspace.RunOperatorOnce(
                    core.CreateOperator(
                        "RedisStoreHandlerCreate",
                        [],
                        [store_handler],
                        prefix=str(TestCase.test_counter) + "/",
                        host=redis_host,
                        port=redis_port))
            else:
                workspace.RunOperatorOnce(
                    core.CreateOperator(
                        "FileStoreHandlerCreate",
                        [],
                        [store_handler],
                        path=tmpdir))
            common_world = "common_world"
        else:
            common_world = str(existing_cw) + ".forked"

        if existing_cw is not None:
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CloneCommonWorld",
                    [existing_cw],
                    [common_world],
                    sync=True,
                    engine=op_engine))
        else:
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateCommonWorld",
                    [store_handler],
                    [common_world],
                    size=comm_size,
                    rank=comm_rank,
                    sync=True,
                    engine=op_engine))
        return (store_handler, common_world)

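    # Broadcast a picklable Python value from rank 0 to all other ranks
    # through the store handler (StoreSet on rank 0, StoreGet elsewhere).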
    def synchronize(self, store_handler, value, comm_rank=None):
        TestCase.sync_counter += 1
        blob = "sync_{}".format(TestCase.sync_counter)
        if comm_rank == 0:
            workspace.FeedBlob(blob, pickle.dumps(value))
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "StoreSet",
                    [store_handler, blob],
                    []))
        else:
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "StoreGet",
                    [store_handler],
                    [blob]))
        return pickle.loads(workspace.FetchBlob(blob))

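    # Feed distinct per-rank values, broadcast with each rank as root in turn,
    # and verify the result on every rank.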
    def _test_broadcast(self,
                        comm_rank=None,
                        comm_size=None,
                        blob_size=None,
                        num_blobs=None,
                        tmpdir=None,
                        use_float16=False,
                        ):
        store_handler, common_world = self.create_common_world(
            comm_rank=comm_rank,
            comm_size=comm_size,
            tmpdir=tmpdir)

        blob_size = self.synchronize(
            store_handler,
            blob_size,
            comm_rank=comm_rank)

        num_blobs = self.synchronize(
            store_handler,
            num_blobs,
            comm_rank=comm_rank)

        for i in range(comm_size):
            blobs = []
            for j in range(num_blobs):
                blob = "blob_{}".format(j)
                offset = (comm_rank * num_blobs) + j
                value = np.full(blob_size, offset,
                                np.float16 if use_float16 else np.float32)
                workspace.FeedBlob(blob, value)
                blobs.append(blob)

            net = core.Net("broadcast")
            net.Broadcast(
                [common_world] + blobs,
                blobs,
                root=i,
                engine=op_engine)

            workspace.CreateNet(net)
            workspace.RunNet(net.Name())

            for j in range(num_blobs):
                np.testing.assert_array_equal(
                    workspace.FetchBlob(blobs[j]),
                    i * num_blobs)

            # Run the net a few more times to check the operator
            # works not just the first time it's called
            for _tmp in range(4):
                workspace.RunNet(net.Name())

    @given(comm_size=st.integers(min_value=2, max_value=8),
           blob_size=st.integers(min_value=int(1e3), max_value=int(1e6)),
           num_blobs=st.integers(min_value=1, max_value=4),
           device_option=st.sampled_from([hu.cpu_do]),
           use_float16=st.booleans())
    @settings(deadline=10000)
    def test_broadcast(self, comm_size, blob_size, num_blobs, device_option,
                       use_float16):
        TestCase.test_counter += 1
        if os.getenv('COMM_RANK') is not None:
            self.run_test_distributed(
                self._test_broadcast,
                blob_size=blob_size,
                num_blobs=num_blobs,
                use_float16=use_float16,
                device_option=device_option)
        else:
            with TemporaryDirectory() as tmpdir:
                self.run_test_locally(
                    self._test_broadcast,
                    comm_size=comm_size,
                    blob_size=blob_size,
                    num_blobs=num_blobs,
                    device_option=device_option,
                    tmpdir=tmpdir,
                    use_float16=use_float16)

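    # Across all ranks, the blobs hold the values 0 .. N-1 with
    # N = num_blobs * comm_size; the test expects Allreduce to leave every
    # element equal to their sum, N * (N - 1) / 2.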
    def _test_allreduce(self,
                        comm_rank=None,
                        comm_size=None,
                        blob_size=None,
                        num_blobs=None,
                        tmpdir=None,
                        use_float16=False
                        ):
        store_handler, common_world = self.create_common_world(
            comm_rank=comm_rank,
            comm_size=comm_size,
            tmpdir=tmpdir)

        blob_size = self.synchronize(
            store_handler,
            blob_size,
            comm_rank=comm_rank)

        num_blobs = self.synchronize(
            store_handler,
            num_blobs,
            comm_rank=comm_rank)

        blobs = []
        for i in range(num_blobs):
            blob = "blob_{}".format(i)
            value = np.full(blob_size, (comm_rank * num_blobs) + i,
                            np.float16 if use_float16 else np.float32)
            workspace.FeedBlob(blob, value)
            blobs.append(blob)

        net = core.Net("allreduce")
        net.Allreduce(
            [common_world] + blobs,
            blobs,
            engine=op_engine)

        workspace.CreateNet(net)
        workspace.RunNet(net.Name())

        for i in range(num_blobs):
            np.testing.assert_array_equal(
                workspace.FetchBlob(blobs[i]),
                (num_blobs * comm_size) * (num_blobs * comm_size - 1) / 2)

        # Run the net a few more times to check the operator
        # works not just the first time it's called
        for _tmp in range(4):
            workspace.RunNet(net.Name())

    def _test_allreduce_multicw(self,
                                comm_rank=None,
                                comm_size=None,
                                tmpdir=None
                                ):
        _store_handler, common_world = self.create_common_world(
            comm_rank=comm_rank,
            comm_size=comm_size,
            tmpdir=tmpdir)

        _, common_world2 = self.create_common_world(
            comm_rank=comm_rank,
            comm_size=comm_size,
            tmpdir=tmpdir,
            existing_cw=common_world)

        blob_size = int(1e4)
        num_blobs = 4

        for cw in [common_world, common_world2]:
            blobs = []
            for i in range(num_blobs):
                blob = "blob_{}".format(i)
                value = np.full(blob_size, (comm_rank * num_blobs) + i, np.float32)
                workspace.FeedBlob(blob, value)
                blobs.append(blob)

            net = core.Net("allreduce_multicw")
            net.Allreduce(
                [cw] + blobs,
                blobs,
                engine=op_engine)

            workspace.RunNetOnce(net)
            for i in range(num_blobs):
                np.testing.assert_array_equal(
                    workspace.FetchBlob(blobs[i]),
                    (num_blobs * comm_size) * (num_blobs * comm_size - 1) / 2)

    @given(comm_size=st.integers(min_value=2, max_value=8),
           blob_size=st.integers(min_value=int(1e3), max_value=int(1e6)),
           num_blobs=st.integers(min_value=1, max_value=4),
           device_option=st.sampled_from([hu.cpu_do]),
           use_float16=st.booleans())
    @settings(deadline=10000)
    def test_allreduce(self, comm_size, blob_size, num_blobs, device_option,
                       use_float16):
        TestCase.test_counter += 1
        if os.getenv('COMM_RANK') is not None:
            self.run_test_distributed(
                self._test_allreduce,
                blob_size=blob_size,
                num_blobs=num_blobs,
                use_float16=use_float16,
                device_option=device_option)
        else:
            with TemporaryDirectory() as tmpdir:
                self.run_test_locally(
                    self._test_allreduce,
                    comm_size=comm_size,
                    blob_size=blob_size,
                    num_blobs=num_blobs,
                    device_option=device_option,
                    tmpdir=tmpdir,
                    use_float16=use_float16)

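    # ReduceScatter sums across ranks like Allreduce, but each rank only keeps
    # its own slice of the result; recv_counts gives the number of elements
    # assigned to each rank.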
    def _test_reduce_scatter(self,
                             comm_rank=None,
                             comm_size=None,
                             blob_size=None,
                             num_blobs=None,
                             tmpdir=None,
                             use_float16=False
                             ):
        store_handler, common_world = self.create_common_world(
            comm_rank=comm_rank,
            comm_size=comm_size,
            tmpdir=tmpdir)

        blob_size = self.synchronize(
            store_handler,
            blob_size,
            comm_rank=comm_rank)

        num_blobs = self.synchronize(
            store_handler,
            num_blobs,
            comm_rank=comm_rank)

        blobs = []
        for i in range(num_blobs):
            blob = "blob_{}".format(i)
            value = np.full(blob_size, (comm_rank * num_blobs) + i,
                            np.float16 if use_float16 else np.float32)
            workspace.FeedBlob(blob, value)
            blobs.append(blob)

        # Specify distribution among ranks i.e. number of elements
        # scattered/distributed to each process.
        recv_counts = np.zeros(comm_size, dtype=np.int32)
        remaining = blob_size
        # Use floor division so chunk_size stays an integer under Python 3.
        chunk_size = (blob_size + comm_size - 1) // comm_size
        for i in range(comm_size):
            recv_counts[i] = min(chunk_size, remaining)
            remaining = remaining - chunk_size if remaining > chunk_size else 0
        recv_counts_blob = "recvCounts"
        workspace.FeedBlob(recv_counts_blob, recv_counts)
        blobs.append(recv_counts_blob)

        net = core.Net("reduce_scatter")
        net.ReduceScatter(
            [common_world] + blobs,
            blobs,
            engine=op_engine)

        workspace.CreateNet(net)
        workspace.RunNet(net.Name())

        for i in range(num_blobs):
            np.testing.assert_array_equal(
                np.resize(workspace.FetchBlob(blobs[i]), recv_counts[comm_rank]),
                (num_blobs * comm_size) * (num_blobs * comm_size - 1) / 2)

        # Run the net a few more times to check the operator
        # works not just the first time it's called
        for _tmp in range(4):
            workspace.RunNet(net.Name())

    @given(comm_size=st.integers(min_value=2, max_value=8),
           blob_size=st.integers(min_value=int(1e3), max_value=int(1e6)),
           num_blobs=st.integers(min_value=1, max_value=4),
           device_option=st.sampled_from([hu.cpu_do]),
           use_float16=st.booleans())
    @settings(deadline=10000)
    def test_reduce_scatter(self, comm_size, blob_size, num_blobs,
                            device_option, use_float16):
        TestCase.test_counter += 1
        if os.getenv('COMM_RANK') is not None:
            self.run_test_distributed(
                self._test_reduce_scatter,
                blob_size=blob_size,
                num_blobs=num_blobs,
                use_float16=use_float16,
                device_option=device_option)
        else:
            with TemporaryDirectory() as tmpdir:
                self.run_test_locally(
                    self._test_reduce_scatter,
                    comm_size=comm_size,
                    blob_size=blob_size,
                    num_blobs=num_blobs,
                    device_option=device_option,
                    tmpdir=tmpdir,
                    use_float16=use_float16)

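    # Allgather concatenates every rank's input blobs into a single "Gathered"
    # output, ordered by rank and then by blob index.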
    def _test_allgather(self,
                        comm_rank=None,
                        comm_size=None,
                        blob_size=None,
                        num_blobs=None,
                        tmpdir=None,
                        use_float16=False
                        ):
        store_handler, common_world = self.create_common_world(
            comm_rank=comm_rank,
            comm_size=comm_size,
            tmpdir=tmpdir)

        blob_size = self.synchronize(
            store_handler,
            blob_size,
            comm_rank=comm_rank)

        num_blobs = self.synchronize(
            store_handler,
            num_blobs,
            comm_rank=comm_rank)

        blobs = []
        for i in range(num_blobs):
            blob = "blob_{}".format(i)
            value = np.full(blob_size, (comm_rank * num_blobs) + i,
                            np.float16 if use_float16 else np.float32)
            workspace.FeedBlob(blob, value)
            blobs.append(blob)

        net = core.Net("allgather")
        net.Allgather(
            [common_world] + blobs,
            ["Gathered"],
            engine=op_engine)

        workspace.CreateNet(net)
        workspace.RunNet(net.Name())

        # create expected output
        expected_output = np.array([])
        for i in range(comm_size):
            for j in range(num_blobs):
                value = np.full(blob_size, (i * num_blobs) + j,
                                np.float16 if use_float16 else np.float32)
                expected_output = np.concatenate((expected_output, value))
        np.testing.assert_array_equal(
            workspace.FetchBlob("Gathered"), expected_output)

        # Run the net a few more times to check the operator
        # works not just the first time it's called
        for _tmp in range(4):
            workspace.RunNet(net.Name())

    @given(comm_size=st.integers(min_value=2, max_value=8),
           blob_size=st.integers(min_value=int(1e3), max_value=int(1e6)),
           num_blobs=st.integers(min_value=1, max_value=4),
           device_option=st.sampled_from([hu.cpu_do]),
           use_float16=st.booleans())
    @settings(max_examples=10, deadline=None)
    def test_allgather(self, comm_size, blob_size, num_blobs, device_option,
                       use_float16):
        TestCase.test_counter += 1
        if os.getenv('COMM_RANK') is not None:
            self.run_test_distributed(
                self._test_allgather,
                blob_size=blob_size,
                num_blobs=num_blobs,
                use_float16=use_float16,
                device_option=device_option)
        else:
            with TemporaryDirectory() as tmpdir:
                self.run_test_locally(
                    self._test_allgather,
                    comm_size=comm_size,
                    blob_size=blob_size,
                    num_blobs=num_blobs,
                    device_option=device_option,
                    tmpdir=tmpdir,
                    use_float16=use_float16)

    @given(device_option=st.sampled_from([hu.cpu_do]))
    @settings(deadline=10000)
    def test_forked_cw(self, device_option):
        TestCase.test_counter += 1
        if os.getenv('COMM_RANK') is not None:
            self.run_test_distributed(
                self._test_allreduce_multicw,
                device_option=device_option)
        else:
            # Note: this test exercises the path where we fork a common world.
            # We therefore don't need a comm size larger than 2. It used to be
            # run with comm_size=8, which causes flaky results in a stress run.
            # The flakiness was caused by too many listening sockets being
            # created by Gloo context initialization (8 processes times
            # 7 sockets times 20-way concurrency, plus TIME_WAIT).
            with TemporaryDirectory() as tmpdir:
                self.run_test_locally(
                    self._test_allreduce_multicw,
                    comm_size=2,
                    device_option=device_option,
                    tmpdir=tmpdir)

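    # A bare Barrier should complete on every rank, both on the first run and
    # on subsequent reruns of the same net.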
    def _test_barrier(
        self,
        comm_rank=None,
        comm_size=None,
        tmpdir=None,
    ):
        store_handler, common_world = self.create_common_world(
            comm_rank=comm_rank, comm_size=comm_size, tmpdir=tmpdir
        )

        net = core.Net("barrier")
        net.Barrier(
            [common_world],
            [],
            engine=op_engine)

        workspace.CreateNet(net)
        workspace.RunNet(net.Name())

        # Run the net a few more times to check the operator
        # works not just the first time it's called
        for _tmp in range(4):
            workspace.RunNet(net.Name())

    @given(comm_size=st.integers(min_value=2, max_value=8),
           device_option=st.sampled_from([hu.cpu_do]))
    @settings(deadline=10000)
    def test_barrier(self, comm_size, device_option):
        TestCase.test_counter += 1
        if os.getenv('COMM_RANK') is not None:
            self.run_test_distributed(
                self._test_barrier,
                device_option=device_option)
        else:
            with TemporaryDirectory() as tmpdir:
                self.run_test_locally(
                    self._test_barrier,
                    comm_size=comm_size,
                    device_option=device_option,
                    tmpdir=tmpdir)

    def _test_close_connection(
        self,
        comm_rank=None,
        comm_size=None,
        tmpdir=None,
    ):
        '''
        One node closes its connection while the others wait on a barrier.
        The test checks that all of them eventually exit.
        '''
        # Caffe's for closers only:
        # https://www.youtube.com/watch?v=QMFwFgG9NE8
        closer = comm_rank == comm_size // 2

        store_handler, common_world = self.create_common_world(
            comm_rank=comm_rank, comm_size=comm_size, tmpdir=tmpdir
        )

        net = core.Net("barrier_or_close")
        if not closer:
            net.Barrier(
                [common_world],
                [],
                engine=op_engine)
        else:
            net.DestroyCommonWorld(
                [common_world], [common_world], engine=op_engine)
            # Sleep a bit to ensure others start the barrier
            import time
            time.sleep(0.1)

        workspace.CreateNet(net)
        workspace.RunNet(net.Name())

    @given(comm_size=st.integers(min_value=2, max_value=8),
           device_option=st.sampled_from([hu.cpu_do]))
    @settings(deadline=10000)
    def test_close_connection(self, comm_size, device_option):
        import time
        start_time = time.time()
        TestCase.test_counter += 1
        if os.getenv('COMM_RANK') is not None:
            self.run_test_distributed(
                self._test_close_connection,
                device_option=device_option)
        else:
            with TemporaryDirectory() as tmpdir:
                self.run_test_locally(
                    self._test_close_connection,
                    comm_size=comm_size,
                    device_option=device_option,
                    tmpdir=tmpdir)
        # Check that test finishes quickly because connections get closed.
        # This assert used to check that the end to end runtime was less
        # than 2 seconds, but this may not always be the case if there
        # is significant overhead in starting processes. Ideally, this
        # assert is replaced by one that doesn't depend on time but rather
        # checks the success/failure status of the barrier that is run.
        self.assertLess(time.time() - start_time, 20.0)

    def _test_io_error(
        self,
        comm_rank=None,
        comm_size=None,
        tmpdir=None,
    ):
        '''
        Only one node will participate in allreduce, resulting in an IoError
        '''
        store_handler, common_world = self.create_common_world(
            comm_rank=comm_rank,
            comm_size=comm_size,
            tmpdir=tmpdir)

        if comm_rank == 0:
            blob_size = 1000
            num_blobs = 1

            blobs = []
            for i in range(num_blobs):
                blob = "blob_{}".format(i)
                value = np.full(
                    blob_size, (comm_rank * num_blobs) + i, np.float32
                )
                workspace.FeedBlob(blob, value)
                blobs.append(blob)

            net = core.Net("allreduce")
            net.Allreduce(
                [common_world] + blobs,
                blobs,
                engine=op_engine)

            workspace.CreateNet(net)
            workspace.RunNet(net.Name())

    @given(comm_size=st.integers(min_value=2, max_value=8),
           device_option=st.sampled_from([hu.cpu_do]))
    @settings(deadline=10000)
    def test_io_error(self, comm_size, device_option):
        TestCase.test_counter += 1
        with self.assertRaises(IoError):
            if os.getenv('COMM_RANK') is not None:
                self.run_test_distributed(
                    self._test_io_error,
                    device_option=device_option)
            else:
                with TemporaryDirectory() as tmpdir:
                    self.run_test_locally(
                        self._test_io_error,
                        comm_size=comm_size,
                        device_option=device_option,
                        tmpdir=tmpdir)


if __name__ == "__main__":
    import unittest
    unittest.main()