Mirror of https://gitee.com/ascend/MindSpeed-RL.git (synced 2025-10-20 16:23:45 +08:00)
ci/access_control_test.py (new file, 160 lines)
@@ -0,0 +1,160 @@
import os
from pathlib import Path


def read_files_from_txt(txt_file):
    with open(txt_file, "r") as f:
        return [line.strip() for line in f.readlines()]


def is_examples(file):
    return file.startswith("examples/") or file.startswith("tests/poc/")


def is_markdown(file):
    return file.endswith(".md")


def is_image(file):
    return file.endswith(".jpg") or file.endswith(".png")


def is_txt(file):
    return file.endswith(".txt")


def is_owners(file):
    return file.startswith("OWNERS")


def is_license(file):
    return file.startswith("LICENSE")


def is_ut(file):
    return file.startswith("tests/ut")


def is_no_suffix(file):
    return os.path.splitext(file)[1] == ''


def skip_ci(files, skip_conds):
    # CI can be skipped only when every changed file matches at least one
    # skip predicate; a single non-matching file forces the full pipeline.
    for file in files:
        if not any(condition(file) for condition in skip_conds):
            return False
    return True


def choose_skip_ci(raw_txt_file):
    # No change list means we cannot prove the change is trivial, so run CI.
    if not os.path.exists(raw_txt_file):
        return False

    file_list = read_files_from_txt(raw_txt_file)
    skip_conds = [
        is_examples,
        is_markdown,
        is_image,
        is_txt,
        is_owners,
        is_license,
        is_no_suffix,
    ]

    return skip_ci(file_list, skip_conds)


def filter_exec_ut(raw_txt_file):
    # Returns (True, file_list) only when every changed file is either a UT
    # test or a markdown file, i.e. running just the touched UTs is safe.
    file_list = read_files_from_txt(raw_txt_file)
    filter_conds = [
        is_ut,
        is_markdown,
    ]
    for file in file_list:
        if not any(condition(file) for condition in filter_conds):
            return False, None
    return True, file_list


def acquire_exitcode(command):
    exitcode = os.system(command)
    # os.system returns a raw wait status; os.WEXITSTATUS extracts the
    # child's real exit code from it.
    real_code = os.WEXITSTATUS(exitcode)
    return real_code

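# Illustration (Linux semantics): os.system("exit 3") returns the raw wait
# status 768 (3 << 8); os.WEXITSTATUS(768) recovers the real exit code 3.
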
# =============================
# UT test, run with pytest
# =============================

class UTTest:
    def __init__(self):
        self.base_dir = Path(__file__).absolute().parents[1]
        self.test_dir = os.path.join(self.base_dir, 'tests')
        # Default: run the whole UT directory; run_ut() may narrow this down.
        # (self.test_dir is already absolute, so joining base_dir again was redundant.)
        self.ut_files = os.path.join(self.test_dir, "ut")

    def run_ut(self, raw_txt_file=None):
        if raw_txt_file is not None and os.path.exists(raw_txt_file):
            filtered_results = filter_exec_ut(raw_txt_file)

            if filtered_results[0]:
                filtered_files = filtered_results[1]
                full_path = [os.path.join(self.base_dir, file) for file in filtered_files]
                exist_ut_files = [file for file in full_path if os.path.exists(file) and file.endswith(".py")]
                self.ut_files = " ".join(exist_ut_files)

        command = f"pytest -x --log-cli-level=INFO {self.ut_files}"
        code = acquire_exitcode(command)
        if code == 0:
            print("UT test success")
        else:
            print("UT test failed")
            exit(1)


# ===============================================
# ST test, run with sh.
# ===============================================

class STTest:
    def __init__(self):
        self.base_dir = Path(__file__).absolute().parents[1]
        self.test_dir = os.path.join(self.base_dir, 'tests')

        self.st_dir = "st"
        self.st_shell = os.path.join(
            self.test_dir, self.st_dir, "st_run.sh"
        )

    def run_st(self):
        st_command = f"bash {self.st_shell}"
        st_code = acquire_exitcode(st_command)
        if st_code != 0:
            print("ST test failed, check it.")
            exit(1)


def run_tests(raw_txt_file):
    ut = UTTest()
    st = STTest()
    # Guard the existence check here: choose_skip_ci() returns False for a
    # missing change list, so run_tests() can be reached without the file.
    if os.path.exists(raw_txt_file) and filter_exec_ut(raw_txt_file)[0]:
        ut.run_ut(raw_txt_file)
    else:
        ut.run_ut()
    st.run_st()


def main():
    parent_dir = Path(__file__).absolute().parents[2]
    raw_txt_file = os.path.join(parent_dir, "modify.txt")

    skip_signal = choose_skip_ci(raw_txt_file)
    if skip_signal:
        print("Skipping CI")
    else:
        run_tests(raw_txt_file)


if __name__ == "__main__":
    main()
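
Note: a minimal sketch of how the gating above behaves; the file names are invented for illustration:

    # every changed file matches a skip predicate -> CI is skipped entirely
    docs_only = ["README.md", "examples/grpo/demo.sh", "LICENSE"]
    assert skip_ci(docs_only, [is_examples, is_markdown, is_license])

    # one core source file defeats every predicate -> full CI runs
    core_change = ["mindspeed_rl/trainer.py"]  # hypothetical path
    assert not skip_ci(core_change, [is_examples, is_markdown, is_license])

    # filter_exec_ut() applies the same pattern with [is_ut, is_markdown] to
    # decide whether pytest may run only the touched UT files.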
tests/__init__.py (new empty file)
tests/st/st_run.sh (new file, 1 line)
@@ -0,0 +1 @@
echo "Hello World !"
tests/test_tools/dist_test.py (new file, 349 lines)
@@ -0,0 +1,349 @@
# Copyright (c) Microsoft Corporation.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

# copied from https://github.com/microsoft/DeepSpeed/blob/master/tests/unit/common.py
# reworked/refactored some parts to make it run.
import os
import time
import inspect
import socket
import json
from abc import ABC, abstractmethod

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import pytest
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker

# Worker timeout for tests that hang
TEST_TIMEOUT = 600


def get_xdist_worker_id():
    # pytest-xdist names its workers "gw0", "gw1", ...; strip the prefix.
    xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace("gw", "")
        return int(xdist_worker_id)
    return None


def get_master_port(base_port=29500, port_range_size=1000):
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        # Make xdist workers use different port ranges to avoid race conditions
        base_port += port_range_size * xdist_worker_id

    # Select first open port in range
    port = base_port
    max_port = base_port + port_range_size
    sock = socket.socket()
    while port < max_port:
        try:
            sock.bind(("", port))
            sock.close()
            return str(port)
        except OSError:
            port += 1
    raise IOError("no free ports")

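# Illustrative example (numbers follow the defaults above): xdist worker "gw0"
# scans ports 29500-30499 while "gw1" scans 30500-31499, so concurrent workers
# never probe the same range.
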
class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture.
    """

    world_size = 2
    backend = "nccl"
    init_distributed = True
    set_dist_env = True
    reuse_dist_env = False
    _pool_cache = {}
    exec_timeout = TEST_TIMEOUT

    @abstractmethod
    def run(self):
        ...

    def __call__(self, request=None):
        self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
        world_size = self.world_size
        if not torch.cuda.is_available():
            pytest.skip("only supported in accelerator environments.")

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)

    def _get_fixture_kwargs(self, request, func):
        if not request:
            return {}
        # Grab fixture / parametrize kwargs from pytest request object
        fixture_kwargs = {}
        params = inspect.getfullargspec(func).args
        params.remove("self")
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                pass  # test methods can have kwargs that are not fixtures
        return fixture_kwargs

    def _launch_procs(self, num_procs):
        # Verify we have enough accelerator devices to run this test
        if torch.cuda.is_available() and torch.cuda.device_count() < num_procs:
            pytest.skip(
                f"Skipping test because not enough GPUs are available: {num_procs} required, {torch.cuda.device_count()} available"
            )

        # Set start method to `forkserver` (or `fork`)
        mp.set_start_method("forkserver", force=True)

        # Create process pool or use cached one
        master_port = None
        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                master_port = get_master_port()
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)
            master_port = get_master_port()

        # Run the test
        args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        try:
            skip_msgs = skip_msgs_async.get(self.exec_timeout)
        except mp.TimeoutError:
            # Shortcut to exit pytest in the case of a hung test. This
            # usually means an environment error and the rest of tests will
            # hang (causing super long unit test runtimes)
            pytest.exit("Test hung, exiting", returncode=0)

        # Tear down distributed environment and close process pools
        self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
            assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
            pytest.skip(skip_msgs[0])

    def _dist_run(self, local_rank, num_procs, master_port):
        skip_msg = ""
        if not dist.is_initialized():
            # Initialize torch.distributed, then execute the user function.
            if self.set_dist_env:
                os.environ["MASTER_ADDR"] = "127.0.0.1"
                os.environ["MASTER_PORT"] = str(master_port)
                os.environ["LOCAL_RANK"] = str(local_rank)
                # NOTE: unit tests don't support multi-node so local_rank == global rank
                os.environ["RANK"] = str(local_rank)
                # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE
                # single node launcher would also set LOCAL_SIZE accordingly
                os.environ["LOCAL_SIZE"] = str(num_procs)
                os.environ["WORLD_SIZE"] = str(num_procs)

            print(
                f"Initializing torch.distributed with rank: {local_rank}, world_size: {num_procs}"
            )
            torch.cuda.set_device(local_rank % torch.cuda.device_count())
            init_method = "tcp://"
            master_ip = os.getenv("MASTER_ADDR", "localhost")
            master_port = str(master_port)
            init_method += master_ip + ":" + master_port
            torch.distributed.init_process_group(
                backend=self.backend,
                world_size=num_procs,
                rank=local_rank,
                init_method=init_method,
            )

        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)

        try:
            self.run(**self._fixture_kwargs)
        except BaseException as e:
            if isinstance(e, Skipped):
                skip_msg = e.msg
            else:
                raise e

        return skip_msg

    def _dist_destroy(self):
        if (dist is not None) and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False):
        if force or not self.reuse_dist_env:
            # Ask every worker to tear down its process group before closing.
            pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
            pool.close()
            pool.join()


class DistributedFixture(DistributedExec):
    """
    Implementation that extends @pytest.fixture to allow for distributed execution.
    This is primarily meant to be used when a test requires executing two pieces of
    code with different world sizes.

    There are 2 parameters that can be modified:
        - world_size: int = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside fixture
        - can be reused by multiple tests
        - can accept other fixtures as input

    Limitations:
        - cannot use @pytest.mark.parametrize
        - world_size cannot be modified after definition and only one world_size value is accepted
        - any fixtures used must also be used in the test that uses this fixture (see example below)
        - run() cannot return values; to pass values to a DistributedTest
          object, use class_tmpdir and write to a file (see example below)

    Usage:
        - must implement a run(self, ...) method
        - the fixture is used by adding the class name as an input argument to a test function

    Example:
        @pytest.fixture(params=[10,20])
        def regular_pytest_fixture(request):
            return request.param

        class distributed_fixture_example(DistributedFixture):
            world_size = 4

            def run(self, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                local_rank = os.environ["LOCAL_RANK"]
                print(f"Rank {local_rank} with value {regular_pytest_fixture}")
                with open(os.path.join(class_tmpdir, f"{local_rank}.txt"), "w") as f:
                    f.write(f"{local_rank},{regular_pytest_fixture}")

        class TestExample(DistributedTest):
            world_size = 1

            def test(self, distributed_fixture_example, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                for rank in range(4):
                    with open(os.path.join(class_tmpdir, f"{rank}.txt"), "r") as f:
                        assert f.read() == f"{rank},{regular_pytest_fixture}"
    """

    is_dist_fixture = True

    # These values are just placeholders so that pytest recognizes this as a fixture
    _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None)
    __name__ = ""

    def __init__(self):
        assert isinstance(
            self.world_size, int
        ), "Only one world size is allowed for distributed fixtures"
        self.__name__ = type(self).__name__
        # Bind the marker to the instance so pytest sees the subclass name.
        self._pytestfixturefunction = FixtureFunctionMarker(
            scope="function", params=None, name=self.__name__
        )


class DistributedTest(DistributedExec):
    """
    Implementation for running pytest with distributed execution.

    There are 2 parameters that can be modified:
        - world_size: Union[int,List[int]] = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside tests
        - works with pytest fixtures, parametrize, mark, etc.
        - can contain multiple tests (each of which can be parametrized separately)
        - class methods can be fixtures (usable by tests in this class only)
        - world_size can be changed for individual tests using @pytest.mark.world_size(world_size)
        - class_tmpdir is a fixture that can be used to get a tmpdir shared among
          all tests (including DistributedFixture)

    Usage:
        - class name must start with "Test"
        - must implement one or more test*(self, ...) methods

    Example:
        @pytest.fixture(params=[10,20])
        def val1(request):
            return request.param

        @pytest.mark.fast
        @pytest.mark.parametrize("val2", [30,40])
        class TestExample(DistributedTest):
            world_size = 2

            @pytest.fixture(params=[50,60])
            def val3(self, request):
                return request.param

            def test_1(self, val1, val2, str1="hello world"):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                assert all([val1, val2, str1])

            @pytest.mark.world_size(1)
            @pytest.mark.parametrize("val4", [70,80])
            def test_2(self, val1, val2, val3, val4):
                assert int(os.environ["WORLD_SIZE"]) == 1
                assert all([val1, val2, val3, val4])
    """

    is_dist_test = True

    # Temporary directory that is shared among test methods in a class
    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        fn = tmpdir_factory.mktemp(self.__class__.__name__)
        return fn

    def run(self, **fixture_kwargs):
        self._current_test(**fixture_kwargs)

    def __call__(self, request):
        self._current_test = self._get_current_test_func(request)
        self._fixture_kwargs = self._get_fixture_kwargs(request, self._current_test)

        if not torch.cuda.is_available():
            pytest.skip("only supported in accelerator environments.")

        # Catch world_size override pytest mark
        for mark in getattr(request.function, "pytestmark", []):
            if mark.name == "world_size":
                world_size = mark.args[0]
                break
        else:
            world_size = self.world_size

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)
            time.sleep(0.5)

    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)


def create_testconfig(path: str):
    with open(path) as f:
        raw_data = json.load(f)

    # Flatten each case dict: multi-key dicts become tuples of their values,
    # single-key dicts collapse to the bare value.
    return {k: [tuple(s.values()) if len(s) > 1 else tuple(s.values())[0] for s in v] for k, v in raw_data.items()}
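
Note: a sketch of the config shape create_testconfig() expects; the file name and keys below are invented for illustration:

    # hypothetical test_config.json:
    #   {"test_forward": [{"bs": 2, "seq": 128}, {"bs": 4, "seq": 256}],
    #    "test_seed": [{"seed": 42}]}
    #
    # create_testconfig("test_config.json") then returns
    #   {"test_forward": [(2, 128), (4, 256)], "test_seed": [42]}
    # which plugs directly into @pytest.mark.parametrize.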
tests/ut/__init__.py (new empty file)
tests/ut/mock/test_mock.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# coding=utf-8
# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A minimal initialization test."""

import pytest  # imported only to verify that it is importable
from tests.test_tools.dist_test import DistributedTest


class TestMock(DistributedTest):
    world_size = 1

    def test_mock_op(self):
        assert 1 + 1 == 2, "Failed!"
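
Note: per UTTest.run_ut() in ci/access_control_test.py, this single UT is executed with a command of the form:

    pytest -x --log-cli-level=INFO tests/ut/mock/test_mock.py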