mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Refactor c10d and dist aliases for torch.distributed (#59456)
Summary: **Overview:** This consolidates the `c10d` and `dist` aliases for `torch.distributed` in `test_store.py` into a single alias, `dist`. Both aliases were most likely in use because of incremental additions to the test file rather than by design. Pull Request resolved: https://github.com/pytorch/pytorch/pull/59456 Test Plan: ``` python test/distributed/test_store.py ``` Reviewed By: agolynski Differential Revision: D28910169 Pulled By: andwgu fbshipit-source-id: f830dead29e9de48aaf2845dfa5861c9cccec15d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1183fa3817
commit
bbf7eceaf0
@ -9,13 +9,12 @@ from datetime import timedelta
|
||||
from sys import platform
|
||||
|
||||
import torch
|
||||
import torch.distributed as c10d
|
||||
import torch.distributed as dist
|
||||
|
||||
if not c10d.is_available():
|
||||
print("c10d not available, skipping tests", file=sys.stderr)
|
||||
if not dist.is_available():
|
||||
print("torch.distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.testing._internal.common_utils as common
|
||||
from torch._six import string_classes
|
||||
@ -125,7 +124,7 @@ class FileStoreTest(TestCase, StoreTestBase):
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False)
|
||||
|
||||
def _create_store(self):
|
||||
store = c10d.FileStore(self.file.name, 1)
|
||||
store = dist.FileStore(self.file.name, 1)
|
||||
store.set_timeout(timedelta(seconds=300))
|
||||
return store
|
||||
|
||||
@ -136,7 +135,7 @@ class HashStoreTest(TestCase, StoreTestBase):
|
||||
super(HashStoreTest, self).setUp()
|
||||
|
||||
def _create_store(self):
|
||||
store = c10d.HashStore()
|
||||
store = dist.HashStore()
|
||||
store.set_timeout(timedelta(seconds=300))
|
||||
return store
|
||||
|
||||
@ -145,12 +144,12 @@ class PrefixFileStoreTest(TestCase, StoreTestBase):
|
||||
def setUp(self):
|
||||
super(PrefixFileStoreTest, self).setUp()
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False)
|
||||
self.filestore = c10d.FileStore(self.file.name, 1)
|
||||
self.filestore = dist.FileStore(self.file.name, 1)
|
||||
self.prefix = "test_prefix"
|
||||
self.filestore.set_timeout(timedelta(seconds=300))
|
||||
|
||||
def _create_store(self):
|
||||
return c10d.PrefixStore(self.prefix, self.filestore)
|
||||
return dist.PrefixStore(self.prefix, self.filestore)
|
||||
|
||||
|
||||
class TCPStoreTest(TestCase, StoreTestBase):
|
||||
@ -171,8 +170,8 @@ class TCPStoreTest(TestCase, StoreTestBase):
|
||||
# Use noqa to silence flake8.
|
||||
# Need to store in an unused variable here to ensure the first
|
||||
# object is not destroyed before the second object is created.
|
||||
store1 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
|
||||
store2 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
|
||||
store1 = dist.TCPStore(addr, port, 1, True) # noqa: F841
|
||||
store2 = dist.TCPStore(addr, port, 1, True) # noqa: F841
|
||||
|
||||
# The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
|
||||
# the user and one additional key used to coordinate all the workers.
|
||||
@ -259,7 +258,7 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
|
||||
self.tcpstore.set_timeout(timedelta(seconds=300))
|
||||
|
||||
def _create_store(self):
|
||||
return c10d.PrefixStore(self.prefix, self.tcpstore)
|
||||
return dist.PrefixStore(self.prefix, self.tcpstore)
|
||||
|
||||
# The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
|
||||
# added by the user and one additional key used to coordinate all the
|
||||
@ -269,7 +268,7 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
|
||||
return 6
|
||||
|
||||
|
||||
class MyPythonStore(c10d.Store):
|
||||
class MyPythonStore(dist.Store):
|
||||
def __init__(self):
|
||||
super(MyPythonStore, self).__init__()
|
||||
self.store = dict()
|
||||
@ -305,13 +304,13 @@ class PythonStoreTest(TestCase):
|
||||
# equivalent of StoreTestBase.test_set_get from C++.
|
||||
# See `torch/csrc/distributed/c10d/init.cpp` for the definition
|
||||
# of this test function.
|
||||
c10d._test_python_store(MyPythonStore())
|
||||
dist._test_python_store(MyPythonStore())
|
||||
|
||||
|
||||
class RendezvousTest(TestCase):
|
||||
def test_unknown_handler(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
|
||||
c10d.rendezvous("invalid://")
|
||||
dist.rendezvous("invalid://")
|
||||
|
||||
|
||||
class RendezvousEnvTest(TestCase):
|
||||
@ -323,7 +322,7 @@ class RendezvousEnvTest(TestCase):
|
||||
|
||||
# Single rank
|
||||
os.environ["RANK"] = "0"
|
||||
gen0 = c10d.rendezvous("env://")
|
||||
gen0 = dist.rendezvous("env://")
|
||||
store0, rank0, size0 = next(gen0)
|
||||
self.assertEqual(0, rank0)
|
||||
self.assertEqual(1, size0)
|
||||
@ -337,23 +336,23 @@ class RendezvousEnvTest(TestCase):
|
||||
class RendezvousFileTest(TestCase):
|
||||
def test_common_errors(self):
|
||||
with self.assertRaisesRegex(ValueError, "path missing"):
|
||||
gen = c10d.rendezvous("file://?rank=0&world_size=1")
|
||||
gen = dist.rendezvous("file://?rank=0&world_size=1")
|
||||
next(gen)
|
||||
with self.assertRaisesRegex(ValueError, "rank parameter missing"):
|
||||
gen = c10d.rendezvous("file:///tmp/foo?world_size=1")
|
||||
gen = dist.rendezvous("file:///tmp/foo?world_size=1")
|
||||
next(gen)
|
||||
with self.assertRaisesRegex(ValueError, "size parameter missing"):
|
||||
gen = c10d.rendezvous("file:///tmp/foo?rank=0")
|
||||
gen = dist.rendezvous("file:///tmp/foo?rank=0")
|
||||
next(gen)
|
||||
|
||||
def test_nominal(self):
|
||||
with tempfile.NamedTemporaryFile(delete=False) as file:
|
||||
url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
|
||||
gen0 = c10d.rendezvous(url + "&rank=0")
|
||||
gen0 = dist.rendezvous(url + "&rank=0")
|
||||
store0, rank0, size0 = next(gen0)
|
||||
self.assertEqual(0, rank0)
|
||||
self.assertEqual(2, size0)
|
||||
gen1 = c10d.rendezvous(url + "&rank=1")
|
||||
gen1 = dist.rendezvous(url + "&rank=1")
|
||||
store1, rank1, size1 = next(gen1)
|
||||
self.assertEqual(1, rank1)
|
||||
self.assertEqual(2, size1)
|
||||
@ -377,19 +376,19 @@ class RendezvousTCPTest(TestCase):
|
||||
|
||||
def test_common_errors(self):
|
||||
with self.assertRaisesRegex(ValueError, "port number missing"):
|
||||
gen = c10d.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
|
||||
gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
|
||||
next(gen)
|
||||
with self.assertRaisesRegex(ValueError, "rank parameter missing"):
|
||||
gen = c10d.rendezvous("tcp://127.0.0.1:23456?world_size=1")
|
||||
gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
|
||||
next(gen)
|
||||
with self.assertRaisesRegex(ValueError, "size parameter missing"):
|
||||
gen = c10d.rendezvous("tcp://127.0.0.1:23456?rank=0")
|
||||
gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
|
||||
next(gen)
|
||||
|
||||
@retry_on_connect_failures
|
||||
def test_nominal(self):
|
||||
url = self.create_tcp_url()
|
||||
gen0 = c10d.rendezvous(url + "&rank=0")
|
||||
gen0 = dist.rendezvous(url + "&rank=0")
|
||||
store0, rank0, size0 = next(gen0)
|
||||
self.assertEqual(0, rank0)
|
||||
self.assertEqual(1, size0)
|
||||
@ -404,7 +403,7 @@ class RendezvousTCPTest(TestCase):
|
||||
def test_tcp_store_timeout_set(self):
|
||||
url = self.create_tcp_url()
|
||||
test_store_timeout = timedelta(seconds=10)
|
||||
gen0 = c10d.rendezvous(url + "&rank=0", timeout=test_store_timeout)
|
||||
gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
|
||||
store0, rank0, size0 = next(gen0)
|
||||
# this should time out in 10s. If the timeout passed into rendezvous was
|
||||
# not respected, it will take much longer to timeout.
|
||||
|
Reference in New Issue
Block a user