pytorch/test/distributed/elastic/utils/util_test.py
Kiuk Chung ba75cedfc5 [1/n][torch/elastic][upstream] Move torchelastic/rendezvous to torch/distributed/rendezvous (#53172)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53172

Pull Request resolved: https://github.com/pytorch/elastic/pull/141

Upstreams two modules to torch:

1. `torchelastic.rendezvous`
2. `torchelastic.utils`

These modules were chosen as `[1/n]` since they are the leaf modules in torchelastic.

==== NOTES: ====
1. I'm disabling the etcd_rendezvous and etcd_server tests in CircleCI for the moment, since I need to edit the test Docker images to include the etcd server binary. There are 4-5 test Docker images (one per platform), so setting up and testing those environments will take me some time (T85992919).

2. I've fixed all lint errors in the Python files, but some remain in the C++ files of ZeusRendezvous. I took a look at them, and I don't want to fix those linter errors right now, for two main reasons:
     1. Some of them are more than formatting changes (e.g. std::move vs. pass by value), and I don't want to bundle those changes with the move.
     2. The old rendezvous code (the one we forked from in caffe2/fb) has the same problems, and I think it's better for us to deal with this when we deprecate caffe2/fb/rendezvous in favor of the one in torchelastic (T86012579).
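Tests that depend on an external binary are often gated so that they are skipped, rather than failing, until the CI images provide that binary. A minimal sketch, assuming the tests only need an `etcd` executable on `PATH` (the `etcd_available` helper and the `EtcdRendezvousSmokeTest` class are hypothetical, and this is not necessarily how this change disables the tests in CircleCI):

```python
import shutil
import unittest

# Assumption: the etcd-backed tests only need an `etcd` executable on PATH.
ETCD_BINARY = shutil.which("etcd")


def etcd_available() -> bool:
    """Return True if an `etcd` executable can be found on PATH."""
    return ETCD_BINARY is not None


# Hypothetical test class: every test in it is reported as skipped
# (not failed) on machines where the etcd binary is missing.
@unittest.skipUnless(etcd_available(), "etcd binary not found on PATH")
class EtcdRendezvousSmokeTest(unittest.TestCase):
    def test_binary_is_present(self):
        self.assertIsNotNone(ETCD_BINARY)
```

On a CI image without etcd, the class shows up as skipped, so re-enabling the tests later only requires installing the binary in the image.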

Test Plan:
```
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/data/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/fb/...
buck test mode/dev-nosan //pytorch/elastic/torchelastic/...
```
\+ Sandcastle

Reviewed By: H-Huang

Differential Revision: D26718746

fbshipit-source-id: 67cc0350c3d847221cb3c3038f98f47915362f51
2021-03-05 11:27:57 -08:00


```python
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.utils.logging import get_logger


class TestStore:
    # Stub store whose get() echoes the requested key, so the test can check
    # exactly which keys get_all() asked for and in what order.
    def get(self, key: str):
        return f"retrieved:{key}"


class StoreUtilTest(unittest.TestCase):
    def test_get_data(self):
        store = TestStore()
        data = store_util.get_all(store, "test/store", 10)
        # get_all() should read keys "test/store0" .. "test/store9" in order.
        for idx in range(0, 10):
            self.assertEqual(f"retrieved:test/store{idx}", data[idx])

    def test_synchronize(self):
        class DummyStore:
            # In-memory store pre-populated with the payloads of all three ranks.
            def __init__(self):
                self._data = {
                    "torchelastic/test0": "data0".encode(encoding="UTF-8"),
                    "torchelastic/test1": "data1".encode(encoding="UTF-8"),
                    "torchelastic/test2": "data2".encode(encoding="UTF-8"),
                }

            def set(self, key, value):
                self._data[key] = value

            def get(self, key):
                return self._data[key]

            def set_timeout(self, timeout):
                pass

        # Rank 0 of a 3-rank group contributes "data0" and should get back
        # every rank's payload, ordered by rank.
        data = "data0".encode(encoding="UTF-8")
        store = DummyStore()
        res = store_util.synchronize(store, data, 0, 3, key_prefix="torchelastic/test")
        self.assertEqual(3, len(res))
        for idx, res_data in enumerate(res):
            actual_str = res_data.decode(encoding="UTF-8")
            self.assertEqual(f"data{idx}", actual_str)


class UtilTest(unittest.TestCase):
    def test_get_logger_different(self):
        logger1 = get_logger("name1")
        logger2 = get_logger("name2")
        self.assertNotEqual(logger1.name, logger2.name)

    def test_get_logger(self):
        # With no name, the logger is named after the calling module.
        logger1 = get_logger()
        self.assertEqual(__name__, logger1.name)

    def test_get_logger_none(self):
        logger1 = get_logger(None)
        self.assertEqual(__name__, logger1.name)

    def test_get_logger_custom_name(self):
        logger1 = get_logger("test.module")
        self.assertEqual("test.module", logger1.name)
```
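The two `StoreUtilTest` cases pin down a simple key-prefix protocol: `get_all(store, prefix, size)` returns the values stored under `prefix0` through `prefix{size-1}` in index order, and the stubbed `set` and `set_timeout` methods suggest that `synchronize` writes the caller's payload under `{key_prefix}{rank}` before gathering all `world_size` entries. The torch-free sketch below reproduces that protocol against an in-memory dict; `DictStore`, `gather_by_prefix`, `synchronize_sketch`, and the 300-second timeout are illustrative assumptions, not the `torch.distributed.elastic.utils.store` implementation.

```python
from datetime import timedelta
from typing import Dict, List


class DictStore:
    """Hypothetical in-memory store exposing the calls the test stubs provide."""

    def __init__(self) -> None:
        self._data: Dict[str, bytes] = {}

    def set(self, key: str, value: bytes) -> None:
        self._data[key] = value

    def get(self, key: str) -> bytes:
        # Unlike a real distributed store, this raises KeyError instead of
        # blocking until another rank writes the key.
        return self._data[key]

    def set_timeout(self, timeout: timedelta) -> None:
        pass  # a real store would bound how long get() waits


def gather_by_prefix(store, prefix: str, size: int) -> List[bytes]:
    """Read keys prefix0 .. prefix{size-1}, in index order."""
    return [store.get(f"{prefix}{idx}") for idx in range(size)]


def synchronize_sketch(
    store, data: bytes, rank: int, world_size: int, key_prefix: str
) -> List[bytes]:
    """Publish this rank's payload, then collect every rank's payload."""
    store.set_timeout(timedelta(seconds=300))  # assumed timeout value
    store.set(f"{key_prefix}{rank}", data)
    return gather_by_prefix(store, key_prefix, world_size)
```

Because `DictStore.get` does not block, the sketch only works when the other ranks' keys are already present, which mirrors how `DummyStore` in the test is pre-populated; a real distributed store's `get` would instead wait for the missing keys.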
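`UtilTest` shows that `get_logger()` called with no argument, or with `None`, returns a logger named after the calling module (`__name__`), while an explicit name is used as-is. A minimal sketch of that defaulting behavior, assuming it is enough to read `__name__` from the caller's frame globals (`get_logger_sketch` is hypothetical and skips whatever handler or level setup the real helper performs):

```python
import inspect
import logging
from typing import Optional


def get_logger_sketch(name: Optional[str] = None) -> logging.Logger:
    """Hypothetical helper: return a logger named after the caller's module by default."""
    if name is None:
        frame = inspect.currentframe()
        try:
            # f_back is the frame of whoever called get_logger_sketch().
            caller = frame.f_back if frame is not None else None
            if caller is not None:
                name = caller.f_globals.get("__name__", "__main__")
            else:
                name = "__main__"
        finally:
            del frame  # avoid keeping a reference cycle through the frame object
    return logging.getLogger(name)
```

Reading the caller's module globals keeps call sites short (`log = get_logger_sketch()`), at the cost of relying on CPython frame introspection.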