mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 22:25:10 +08:00
Refactor c10d and dist aliases for torch.distributed (#59456)
Summary: **Overview:** This consolidates `c10d` and `dist` to only `dist` as the alias for `torch.distributed` in `test_store.py`. Both aliases were most likely in use because of incremental additions to the test file rather than by design. Pull Request resolved: https://github.com/pytorch/pytorch/pull/59456 Test Plan: ``` python test/distributed/test_store.py ``` Reviewed By: agolynski Differential Revision: D28910169 Pulled By: andwgu fbshipit-source-id: f830dead29e9de48aaf2845dfa5861c9cccec15d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1183fa3817
commit
bbf7eceaf0
@ -9,13 +9,12 @@ from datetime import timedelta
|
|||||||
from sys import platform
|
from sys import platform
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as c10d
|
import torch.distributed as dist
|
||||||
|
|
||||||
if not c10d.is_available():
|
if not dist.is_available():
|
||||||
print("c10d not available, skipping tests", file=sys.stderr)
|
print("torch.distributed not available, skipping tests", file=sys.stderr)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.testing._internal.common_utils as common
|
import torch.testing._internal.common_utils as common
|
||||||
from torch._six import string_classes
|
from torch._six import string_classes
|
||||||
@ -125,7 +124,7 @@ class FileStoreTest(TestCase, StoreTestBase):
|
|||||||
self.file = tempfile.NamedTemporaryFile(delete=False)
|
self.file = tempfile.NamedTemporaryFile(delete=False)
|
||||||
|
|
||||||
def _create_store(self):
|
def _create_store(self):
|
||||||
store = c10d.FileStore(self.file.name, 1)
|
store = dist.FileStore(self.file.name, 1)
|
||||||
store.set_timeout(timedelta(seconds=300))
|
store.set_timeout(timedelta(seconds=300))
|
||||||
return store
|
return store
|
||||||
|
|
||||||
@ -136,7 +135,7 @@ class HashStoreTest(TestCase, StoreTestBase):
|
|||||||
super(HashStoreTest, self).setUp()
|
super(HashStoreTest, self).setUp()
|
||||||
|
|
||||||
def _create_store(self):
|
def _create_store(self):
|
||||||
store = c10d.HashStore()
|
store = dist.HashStore()
|
||||||
store.set_timeout(timedelta(seconds=300))
|
store.set_timeout(timedelta(seconds=300))
|
||||||
return store
|
return store
|
||||||
|
|
||||||
@ -145,12 +144,12 @@ class PrefixFileStoreTest(TestCase, StoreTestBase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(PrefixFileStoreTest, self).setUp()
|
super(PrefixFileStoreTest, self).setUp()
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False)
|
self.file = tempfile.NamedTemporaryFile(delete=False)
|
||||||
self.filestore = c10d.FileStore(self.file.name, 1)
|
self.filestore = dist.FileStore(self.file.name, 1)
|
||||||
self.prefix = "test_prefix"
|
self.prefix = "test_prefix"
|
||||||
self.filestore.set_timeout(timedelta(seconds=300))
|
self.filestore.set_timeout(timedelta(seconds=300))
|
||||||
|
|
||||||
def _create_store(self):
|
def _create_store(self):
|
||||||
return c10d.PrefixStore(self.prefix, self.filestore)
|
return dist.PrefixStore(self.prefix, self.filestore)
|
||||||
|
|
||||||
|
|
||||||
class TCPStoreTest(TestCase, StoreTestBase):
|
class TCPStoreTest(TestCase, StoreTestBase):
|
||||||
@ -171,8 +170,8 @@ class TCPStoreTest(TestCase, StoreTestBase):
|
|||||||
# Use noqa to silence flake8.
|
# Use noqa to silence flake8.
|
||||||
# Need to store in an unused variable here to ensure the first
|
# Need to store in an unused variable here to ensure the first
|
||||||
# object is not destroyed before the second object is created.
|
# object is not destroyed before the second object is created.
|
||||||
store1 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
|
store1 = dist.TCPStore(addr, port, 1, True) # noqa: F841
|
||||||
store2 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
|
store2 = dist.TCPStore(addr, port, 1, True) # noqa: F841
|
||||||
|
|
||||||
# The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
|
# The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
|
||||||
# the user and one additional key used to coordinate all the workers.
|
# the user and one additional key used to coordinate all the workers.
|
||||||
@ -259,7 +258,7 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
|
|||||||
self.tcpstore.set_timeout(timedelta(seconds=300))
|
self.tcpstore.set_timeout(timedelta(seconds=300))
|
||||||
|
|
||||||
def _create_store(self):
|
def _create_store(self):
|
||||||
return c10d.PrefixStore(self.prefix, self.tcpstore)
|
return dist.PrefixStore(self.prefix, self.tcpstore)
|
||||||
|
|
||||||
# The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
|
# The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
|
||||||
# added by the user and one additional key used to coordinate all the
|
# added by the user and one additional key used to coordinate all the
|
||||||
@ -269,7 +268,7 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
|
|||||||
return 6
|
return 6
|
||||||
|
|
||||||
|
|
||||||
class MyPythonStore(c10d.Store):
|
class MyPythonStore(dist.Store):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MyPythonStore, self).__init__()
|
super(MyPythonStore, self).__init__()
|
||||||
self.store = dict()
|
self.store = dict()
|
||||||
@ -305,13 +304,13 @@ class PythonStoreTest(TestCase):
|
|||||||
# equivalent of StoreTestBase.test_set_get from C++.
|
# equivalent of StoreTestBase.test_set_get from C++.
|
||||||
# See `torch/csrc/distributed/c10d/init.cpp` for the definition
|
# See `torch/csrc/distributed/c10d/init.cpp` for the definition
|
||||||
# of this test function.
|
# of this test function.
|
||||||
c10d._test_python_store(MyPythonStore())
|
dist._test_python_store(MyPythonStore())
|
||||||
|
|
||||||
|
|
||||||
class RendezvousTest(TestCase):
|
class RendezvousTest(TestCase):
|
||||||
def test_unknown_handler(self):
|
def test_unknown_handler(self):
|
||||||
with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
|
with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
|
||||||
c10d.rendezvous("invalid://")
|
dist.rendezvous("invalid://")
|
||||||
|
|
||||||
|
|
||||||
class RendezvousEnvTest(TestCase):
|
class RendezvousEnvTest(TestCase):
|
||||||
@ -323,7 +322,7 @@ class RendezvousEnvTest(TestCase):
|
|||||||
|
|
||||||
# Single rank
|
# Single rank
|
||||||
os.environ["RANK"] = "0"
|
os.environ["RANK"] = "0"
|
||||||
gen0 = c10d.rendezvous("env://")
|
gen0 = dist.rendezvous("env://")
|
||||||
store0, rank0, size0 = next(gen0)
|
store0, rank0, size0 = next(gen0)
|
||||||
self.assertEqual(0, rank0)
|
self.assertEqual(0, rank0)
|
||||||
self.assertEqual(1, size0)
|
self.assertEqual(1, size0)
|
||||||
@ -337,23 +336,23 @@ class RendezvousEnvTest(TestCase):
|
|||||||
class RendezvousFileTest(TestCase):
|
class RendezvousFileTest(TestCase):
|
||||||
def test_common_errors(self):
|
def test_common_errors(self):
|
||||||
with self.assertRaisesRegex(ValueError, "path missing"):
|
with self.assertRaisesRegex(ValueError, "path missing"):
|
||||||
gen = c10d.rendezvous("file://?rank=0&world_size=1")
|
gen = dist.rendezvous("file://?rank=0&world_size=1")
|
||||||
next(gen)
|
next(gen)
|
||||||
with self.assertRaisesRegex(ValueError, "rank parameter missing"):
|
with self.assertRaisesRegex(ValueError, "rank parameter missing"):
|
||||||
gen = c10d.rendezvous("file:///tmp/foo?world_size=1")
|
gen = dist.rendezvous("file:///tmp/foo?world_size=1")
|
||||||
next(gen)
|
next(gen)
|
||||||
with self.assertRaisesRegex(ValueError, "size parameter missing"):
|
with self.assertRaisesRegex(ValueError, "size parameter missing"):
|
||||||
gen = c10d.rendezvous("file:///tmp/foo?rank=0")
|
gen = dist.rendezvous("file:///tmp/foo?rank=0")
|
||||||
next(gen)
|
next(gen)
|
||||||
|
|
||||||
def test_nominal(self):
|
def test_nominal(self):
|
||||||
with tempfile.NamedTemporaryFile(delete=False) as file:
|
with tempfile.NamedTemporaryFile(delete=False) as file:
|
||||||
url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
|
url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
|
||||||
gen0 = c10d.rendezvous(url + "&rank=0")
|
gen0 = dist.rendezvous(url + "&rank=0")
|
||||||
store0, rank0, size0 = next(gen0)
|
store0, rank0, size0 = next(gen0)
|
||||||
self.assertEqual(0, rank0)
|
self.assertEqual(0, rank0)
|
||||||
self.assertEqual(2, size0)
|
self.assertEqual(2, size0)
|
||||||
gen1 = c10d.rendezvous(url + "&rank=1")
|
gen1 = dist.rendezvous(url + "&rank=1")
|
||||||
store1, rank1, size1 = next(gen1)
|
store1, rank1, size1 = next(gen1)
|
||||||
self.assertEqual(1, rank1)
|
self.assertEqual(1, rank1)
|
||||||
self.assertEqual(2, size1)
|
self.assertEqual(2, size1)
|
||||||
@ -377,19 +376,19 @@ class RendezvousTCPTest(TestCase):
|
|||||||
|
|
||||||
def test_common_errors(self):
|
def test_common_errors(self):
|
||||||
with self.assertRaisesRegex(ValueError, "port number missing"):
|
with self.assertRaisesRegex(ValueError, "port number missing"):
|
||||||
gen = c10d.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
|
gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
|
||||||
next(gen)
|
next(gen)
|
||||||
with self.assertRaisesRegex(ValueError, "rank parameter missing"):
|
with self.assertRaisesRegex(ValueError, "rank parameter missing"):
|
||||||
gen = c10d.rendezvous("tcp://127.0.0.1:23456?world_size=1")
|
gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
|
||||||
next(gen)
|
next(gen)
|
||||||
with self.assertRaisesRegex(ValueError, "size parameter missing"):
|
with self.assertRaisesRegex(ValueError, "size parameter missing"):
|
||||||
gen = c10d.rendezvous("tcp://127.0.0.1:23456?rank=0")
|
gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
|
||||||
next(gen)
|
next(gen)
|
||||||
|
|
||||||
@retry_on_connect_failures
|
@retry_on_connect_failures
|
||||||
def test_nominal(self):
|
def test_nominal(self):
|
||||||
url = self.create_tcp_url()
|
url = self.create_tcp_url()
|
||||||
gen0 = c10d.rendezvous(url + "&rank=0")
|
gen0 = dist.rendezvous(url + "&rank=0")
|
||||||
store0, rank0, size0 = next(gen0)
|
store0, rank0, size0 = next(gen0)
|
||||||
self.assertEqual(0, rank0)
|
self.assertEqual(0, rank0)
|
||||||
self.assertEqual(1, size0)
|
self.assertEqual(1, size0)
|
||||||
@ -404,7 +403,7 @@ class RendezvousTCPTest(TestCase):
|
|||||||
def test_tcp_store_timeout_set(self):
|
def test_tcp_store_timeout_set(self):
|
||||||
url = self.create_tcp_url()
|
url = self.create_tcp_url()
|
||||||
test_store_timeout = timedelta(seconds=10)
|
test_store_timeout = timedelta(seconds=10)
|
||||||
gen0 = c10d.rendezvous(url + "&rank=0", timeout=test_store_timeout)
|
gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
|
||||||
store0, rank0, size0 = next(gen0)
|
store0, rank0, size0 = next(gen0)
|
||||||
# this should time out in 10s. If the timeout passed into rendezvous was
|
# this should time out in 10s. If the timeout passed into rendezvous was
|
||||||
# not respected, it will take much longer to time out.
|
# not respected, it will take much longer to time out.
|
||||||
|
Reference in New Issue
Block a user