Refactor c10d and dist aliases for torch.distributed (#59456)

Summary:
**Overview:**
This consolidates the two aliases `c10d` and `dist` into a single alias, `dist`, for `torch.distributed` in `test_store.py`. Both aliases were most likely in use because of incremental additions to the test file rather than by design.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59456

Test Plan:
```
python test/distributed/test_store.py
```

Reviewed By: agolynski

Differential Revision: D28910169

Pulled By: andwgu

fbshipit-source-id: f830dead29e9de48aaf2845dfa5861c9cccec15d
This commit is contained in:
Andrew Gu
2021-06-04 16:06:11 -07:00
committed by Facebook GitHub Bot
parent 1183fa3817
commit bbf7eceaf0

View File

@ -9,13 +9,12 @@ from datetime import timedelta
from sys import platform
import torch
import torch.distributed as c10d
import torch.distributed as dist
if not c10d.is_available():
print("c10d not available, skipping tests", file=sys.stderr)
if not dist.is_available():
print("torch.distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.testing._internal.common_utils as common
from torch._six import string_classes
@ -125,7 +124,7 @@ class FileStoreTest(TestCase, StoreTestBase):
self.file = tempfile.NamedTemporaryFile(delete=False)
def _create_store(self):
store = c10d.FileStore(self.file.name, 1)
store = dist.FileStore(self.file.name, 1)
store.set_timeout(timedelta(seconds=300))
return store
@ -136,7 +135,7 @@ class HashStoreTest(TestCase, StoreTestBase):
super(HashStoreTest, self).setUp()
def _create_store(self):
store = c10d.HashStore()
store = dist.HashStore()
store.set_timeout(timedelta(seconds=300))
return store
@ -145,12 +144,12 @@ class PrefixFileStoreTest(TestCase, StoreTestBase):
def setUp(self):
super(PrefixFileStoreTest, self).setUp()
self.file = tempfile.NamedTemporaryFile(delete=False)
self.filestore = c10d.FileStore(self.file.name, 1)
self.filestore = dist.FileStore(self.file.name, 1)
self.prefix = "test_prefix"
self.filestore.set_timeout(timedelta(seconds=300))
def _create_store(self):
return c10d.PrefixStore(self.prefix, self.filestore)
return dist.PrefixStore(self.prefix, self.filestore)
class TCPStoreTest(TestCase, StoreTestBase):
@ -171,8 +170,8 @@ class TCPStoreTest(TestCase, StoreTestBase):
# Use noqa to silence flake8.
# Need to store in an unused variable here to ensure the first
# object is not destroyed before the second object is created.
store1 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
store2 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
store1 = dist.TCPStore(addr, port, 1, True) # noqa: F841
store2 = dist.TCPStore(addr, port, 1, True) # noqa: F841
# The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
# the user and one additional key used to coordinate all the workers.
@ -259,7 +258,7 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
self.tcpstore.set_timeout(timedelta(seconds=300))
def _create_store(self):
return c10d.PrefixStore(self.prefix, self.tcpstore)
return dist.PrefixStore(self.prefix, self.tcpstore)
# The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
# added by the user and one additional key used to coordinate all the
@ -269,7 +268,7 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
return 6
class MyPythonStore(c10d.Store):
class MyPythonStore(dist.Store):
def __init__(self):
super(MyPythonStore, self).__init__()
self.store = dict()
@ -305,13 +304,13 @@ class PythonStoreTest(TestCase):
# equivalent of StoreTestBase.test_set_get from C++.
# See `torch/csrc/distributed/c10d/init.cpp` for the definition
# of this test function.
c10d._test_python_store(MyPythonStore())
dist._test_python_store(MyPythonStore())
class RendezvousTest(TestCase):
def test_unknown_handler(self):
with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
c10d.rendezvous("invalid://")
dist.rendezvous("invalid://")
class RendezvousEnvTest(TestCase):
@ -323,7 +322,7 @@ class RendezvousEnvTest(TestCase):
# Single rank
os.environ["RANK"] = "0"
gen0 = c10d.rendezvous("env://")
gen0 = dist.rendezvous("env://")
store0, rank0, size0 = next(gen0)
self.assertEqual(0, rank0)
self.assertEqual(1, size0)
@ -337,23 +336,23 @@ class RendezvousEnvTest(TestCase):
class RendezvousFileTest(TestCase):
def test_common_errors(self):
with self.assertRaisesRegex(ValueError, "path missing"):
gen = c10d.rendezvous("file://?rank=0&world_size=1")
gen = dist.rendezvous("file://?rank=0&world_size=1")
next(gen)
with self.assertRaisesRegex(ValueError, "rank parameter missing"):
gen = c10d.rendezvous("file:///tmp/foo?world_size=1")
gen = dist.rendezvous("file:///tmp/foo?world_size=1")
next(gen)
with self.assertRaisesRegex(ValueError, "size parameter missing"):
gen = c10d.rendezvous("file:///tmp/foo?rank=0")
gen = dist.rendezvous("file:///tmp/foo?rank=0")
next(gen)
def test_nominal(self):
with tempfile.NamedTemporaryFile(delete=False) as file:
url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
gen0 = c10d.rendezvous(url + "&rank=0")
gen0 = dist.rendezvous(url + "&rank=0")
store0, rank0, size0 = next(gen0)
self.assertEqual(0, rank0)
self.assertEqual(2, size0)
gen1 = c10d.rendezvous(url + "&rank=1")
gen1 = dist.rendezvous(url + "&rank=1")
store1, rank1, size1 = next(gen1)
self.assertEqual(1, rank1)
self.assertEqual(2, size1)
@ -377,19 +376,19 @@ class RendezvousTCPTest(TestCase):
def test_common_errors(self):
with self.assertRaisesRegex(ValueError, "port number missing"):
gen = c10d.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
next(gen)
with self.assertRaisesRegex(ValueError, "rank parameter missing"):
gen = c10d.rendezvous("tcp://127.0.0.1:23456?world_size=1")
gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
next(gen)
with self.assertRaisesRegex(ValueError, "size parameter missing"):
gen = c10d.rendezvous("tcp://127.0.0.1:23456?rank=0")
gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
next(gen)
@retry_on_connect_failures
def test_nominal(self):
url = self.create_tcp_url()
gen0 = c10d.rendezvous(url + "&rank=0")
gen0 = dist.rendezvous(url + "&rank=0")
store0, rank0, size0 = next(gen0)
self.assertEqual(0, rank0)
self.assertEqual(1, size0)
@ -404,7 +403,7 @@ class RendezvousTCPTest(TestCase):
def test_tcp_store_timeout_set(self):
url = self.create_tcp_url()
test_store_timeout = timedelta(seconds=10)
gen0 = c10d.rendezvous(url + "&rank=0", timeout=test_store_timeout)
gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
store0, rank0, size0 = next(gen0)
# This should time out in 10s. If the timeout passed into rendezvous was
# not respected, it will take much longer to time out.