mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53172 Pull Request resolved: https://github.com/pytorch/elastic/pull/141 Upstreams two modules to torch: 1. `torchelastic.rendezvous` 2. `torchelastic.utils` These modules were chosen as `[1/n]` since they are the leaf modules in torchelastic. ==== NOTES: ==== 1. I'm disabling the etcd_rendezvous and etcd_server tests in CircleCI for the moment since I need to edit the test dockers to contain the etcd server binary (there are 4-5 test dockers, one for each platform, so it is going to take some time for me to set up the environments and test) - T85992919. 2. I've fixed all lint errors in the Python files, but there are still some in the cpp files of ZeusRendezvous. I took a look at them, and I don't want to fix those linter errors right now for 2 major reasons: (a) Some of them are more than formatting changes (e.g. std::move vs pass by value) and I don't want to introduce bundled changes with the move. (b) The old rendezvous code (the one we forked from in caffe2/fb) has the same problems, and I think it's better for us to deal with this when we deprecate caffe2/fb/rendezvous in favor of the one in torchelastic - T86012579. Test Plan: ``` buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/test/... buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/data/test/... buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/test/... buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/fb/... buck test mode/dev-nosan //pytorch/elastic/torchelastic/... ``` + Sandcastle Reviewed By: H-Huang Differential Revision: D26718746 fbshipit-source-id: 67cc0350c3d847221cb3c3038f98f47915362f51
71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import torch.distributed.elastic.utils.store as store_util
|
|
from torch.distributed.elastic.utils.logging import get_logger
|
|
|
|
|
|
class TestStore:
    """Minimal stand-in for a key-value store, used by the tests in this file.

    Only ``get`` is provided. It needs no backing storage because the value
    it returns is derived entirely from the requested key.
    """

    # The class name matches pytest's default ``Test*`` collection pattern;
    # opt out explicitly because this is a helper, not a test case.
    __test__ = False

    def get(self, key: str) -> str:
        """Return ``key`` prefixed with ``"retrieved:"``.

        Echoing the key back lets a test verify which keys were requested
        simply by inspecting the returned values.
        """
        return f"retrieved:{key}"
|
|
|
|
|
|
class StoreUtilTest(unittest.TestCase):
    """Tests for the ``get_all`` and ``synchronize`` helpers of ``store_util``."""

    def test_get_data(self):
        """``get_all`` yields the value stored under ``prefix + idx`` at each index."""
        size = 10
        store = TestStore()
        data = store_util.get_all(store, "test/store", size)

        # Guard against extra trailing entries, which the per-index
        # comparison below would not notice.
        self.assertEqual(size, len(data))
        for idx in range(size):
            # TestStore echoes the requested key, so each value reveals
            # the key that was looked up for that position.
            self.assertEqual(f"retrieved:test/store{idx}", data[idx])

    def test_synchronize(self):
        """``synchronize`` returns the payloads stored under ``key_prefix + idx``."""

        class DummyStore:
            """Dict-backed store that already holds one payload per index 0..2."""

            def __init__(self):
                self._data = {
                    "torchelastic/test0": b"data0",
                    "torchelastic/test1": b"data1",
                    "torchelastic/test2": b"data2",
                }

            def set(self, key, value):
                self._data[key] = value

            def get(self, key):
                return self._data[key]

            def set_timeout(self, timeout):
                # No-op: this store only does plain dict reads and writes,
                # so there is nothing for a timeout to apply to.
                pass

        expected_count = 3
        data = b"data0"
        store = DummyStore()
        # NOTE(review): the positional ``0`` is presumably this caller's
        # index among the ``expected_count`` participants - confirm against
        # the signature of ``store_util.synchronize``.
        res = store_util.synchronize(
            store, data, 0, expected_count, key_prefix="torchelastic/test"
        )

        self.assertEqual(expected_count, len(res))
        for idx, res_data in enumerate(res):
            actual_str = res_data.decode(encoding="UTF-8")
            self.assertEqual(f"data{idx}", actual_str)
|
|
|
|
|
|
class UtilTest(unittest.TestCase):
    """Tests for how ``get_logger`` names the logger it returns."""

    def test_get_logger_different(self):
        """Two different explicit names produce loggers with different names."""
        logger1 = get_logger("name1")
        logger2 = get_logger("name2")
        self.assertNotEqual(logger1.name, logger2.name)

    def test_get_logger(self):
        """With no argument, the logger is expected to carry this module's name."""
        logger = get_logger()
        self.assertEqual(__name__, logger.name)

    def test_get_logger_none(self):
        """An explicit ``None`` is expected to behave like omitting the name."""
        logger = get_logger(None)
        self.assertEqual(__name__, logger.name)

    def test_get_logger_custom_name(self):
        """An explicit dotted name is used verbatim as the logger's name."""
        logger = get_logger("test.module")
        self.assertEqual("test.module", logger.name)
|