!235 Add test cases to templates and fusion operators
Merge pull request !235 from 孙银磊/master-add-ut
tests/unit/cli/test_cli_main.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#     http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import unittest
from contextlib import redirect_stdout
from io import StringIO
from unittest.mock import patch

from openmind.cli.cli import get_device_count
from openmind.cli.cli import main as cli_main


class TestCliMain(unittest.TestCase):
    def test_get_device_count(self):
        captured_output = StringIO()
        with redirect_stdout(captured_output):
            device_count = get_device_count()
        self.assertGreaterEqual(device_count, 0)
        self.assertLessEqual(device_count, 16)

    @patch(
        "sys.argv",
        ["openmind-cli", "none"],
    )
    def test_cli_main_failed(self):
        with self.assertRaises(ValueError):
            cli_main()
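Reviewer note: the @patch("sys.argv", ...) idiom above swaps the real command line for the duration of the test, so cli_main sees "none" as its subcommand and is expected to raise ValueError. A minimal self-contained illustration of the same pattern (the parsing logic here is a hypothetical stand-in, not openmind's actual CLI):

    import sys
    from unittest.mock import patch

    def parse_subcommand():
        # Hypothetical stand-in for an entry point that reads sys.argv.
        if sys.argv[1] not in {"run", "list"}:
            raise ValueError(f"unknown subcommand: {sys.argv[1]}")
        return sys.argv[1]

    with patch("sys.argv", ["openmind-cli", "none"]):
        try:
            parse_subcommand()
        except ValueError:
            print("raised ValueError, as test_cli_main_failed expects")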
@@ -10,21 +10,33 @@
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.


import contextlib
import json
import os
import tempfile
import unittest
from collections import OrderedDict
from unittest.mock import patch, Mock, MagicMock

from openmind.flow.arguments import get_args, initialize_openmind
from openmind.flow.datasets import get_template
from openmind.flow.datasets.template import Template
from openmind.flow.model.model_registry import ModelMetadata
from openmind.utils.constants import Tokens

from tests.utils_for_test import require_torch, require_transformers, slow


@contextlib.contextmanager
def create_tmp_dir():
    # Point XDG_CACHE_HOME at a throwaway directory for the duration of a test.
    tmp_dir = tempfile.TemporaryDirectory(dir=os.path.expanduser("~"))
    os.environ["XDG_CACHE_HOME"] = tmp_dir.name

    yield tmp_dir.name

    os.environ.pop("XDG_CACHE_HOME", None)
    tmp_dir.cleanup()


class TestTemplate(unittest.TestCase):
    def setUp(self):
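For orientation, create_tmp_dir is consumed as an ordinary context manager; a minimal sketch of the intended usage pattern (illustrative, not copied from the test bodies):

    import os

    with create_tmp_dir() as cache_dir:
        # While inside the block, anything honoring XDG_CACHE_HOME writes
        # under the throwaway directory instead of the user's real cache.
        assert os.environ["XDG_CACHE_HOME"] == cache_dir
    # On exit the variable is unset and the directory is deleted.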
@@ -100,6 +112,29 @@ class TestTemplate(unittest.TestCase):
        self.assertEqual(template.assistant_template, mock_template_data["assistant_template"])
        self.assertEqual(template.prefix_template, mock_template_data["prefix_template"])

    def test_custom_template_json_failed(self):
        mock_template_data = {
            "system_template": "system_template",
            "user_template": "user_template",
            "assistant_template": "assistant_template",
            "prefix_template": ["prefix_template"],
        }

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_str = json.dumps(mock_template_data)
            tmp_file_path = os.path.join(tmpdirname, "template.json")
            with open(tmp_file_path, "w") as tmp_file:
                tmp_file.write(json_str)

            args = get_args()
            args.template = None
            with self.assertRaises(ValueError):
                get_template()

            args.template = "None"
            with self.assertRaises(ValueError):
                get_template()


class TestTemplateFunctions(unittest.TestCase):
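The two assertRaises calls above pin down that get_template rejects both a missing template name (None) and an unregistered one ("None"). A plausible shape of that guard, written out for clarity (hypothetical sketch, not openmind's actual implementation):

    def get_template_checked(name, registry):
        # Hypothetical: reject a missing name and an unregistered one alike.
        if name is None or name not in registry:
            raise ValueError(f"Unknown chat template: {name!r}")
        return registry[name]

    # get_template_checked("None", registry={})  # would raise ValueError, like the test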
@@ -170,3 +205,73 @@ class TestTemplateFunctions(unittest.TestCase):
        pairs_2 = template._make_pairs(encoded_messages_2, cutoff_len=15, reserved_label_len=5)
        expected_2 = [([1] * 10, [2] * 5)]
        self.assertEqual(pairs_2, expected_2)

    @slow
    @require_torch
    @require_transformers
    @patch("openmind.flow.datasets.template.logger.info_rank0")
    def test_add_or_replace_eos_token(self, mock_logger):
        with create_tmp_dir():
            from openmind import AutoTokenizer

            template = Template()
            tokenizer = AutoTokenizer.from_pretrained("modeltesting/tiny_random_llama", use_fast=False)
            eos_token = "<|add_eos_token_0|>"
            template._add_or_replace_eos_token(tokenizer, eos_token=eos_token)
            expected_msg = "Replace eos token: <|add_eos_token_0|>."
            mock_logger.assert_called_once_with(expected_msg)

    @slow
    @require_torch
    @require_transformers
    @patch("openmind.flow.datasets.template.get_args")
    def test_encode_oneturn(self, mock_get_args):
        from openmind import AutoTokenizer

        template = Template(
            system_template=r"<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>",
            user_template=r"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>"
            r"<|start_header_id|>assistant<|end_header_id|>\n\n",
            assistant_template=r"{content}",
            prefix_template=r"bos_token",
        )
        mock_get_args.return_value = MagicMock(cutoff_len=10, reserved_label_len=10)
        prompt = [
            {"role": "assistant", "content": "prompt for assistant"},
            {"role": "user", "content": "prompt for user"},
        ]

        tokenizer = AutoTokenizer.from_pretrained("modeltesting/tiny_random_llama", use_fast=False)
        prompt_ids, answer_ids = template.encode_oneturn(tokenizer, prompt, "system_prompt")

        expected_prompt_ids = []
        expected_answer_ids = ([], [529, 29989, 2962, 29918, 6672, 29918, 333, 29989, 29958, 1792])
        self.assertEqual(expected_prompt_ids, prompt_ids)
        self.assertEqual(expected_answer_ids, answer_ids)

    @slow
    @require_torch
    @require_transformers
    @patch("openmind.flow.datasets.template.get_args")
    def test_encode(self, mock_get_args):
        from openmind import AutoTokenizer

        template = Template(
            system_template=r"<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>",
            user_template=r"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>"
            r"<|start_header_id|>assistant<|end_header_id|>\n\n",
            assistant_template=r"{content}",
            prefix_template=r"bos_token",
        )

        mock_get_args.return_value = MagicMock(cutoff_len=10, reserved_label_len=10)
        prompt = [
            {"role": "assistant", "content": "prompt for assistant"},
            {"role": "user", "content": "prompt for user"},
        ]

        tokenizer = AutoTokenizer.from_pretrained("modeltesting/tiny_random_llama", use_fast=False)
        encoded_message = template.encode(tokenizer, prompt, "system_prompt")

        expected_encoded_message = [([], [529, 29989, 2962, 29918, 6672, 29918, 333, 29989, 29958, 1792])]
        self.assertEqual(expected_encoded_message, encoded_message)
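For readers verifying the first assertion in the hunk above: with cutoff_len=15 and reserved_label_len=5, the source side keeps cutoff_len - reserved_label_len = 10 tokens and the label side keeps 5, which is exactly the expected ([1] * 10, [2] * 5) pair. A sketch of that budget split (an illustration of the arithmetic the test exercises, not the library's _make_pairs):

    def split_budget(source_ids, label_ids, cutoff_len, reserved_label_len):
        # Hypothetical: reserve reserved_label_len tokens for the labels and
        # give the remaining budget to the source side.
        max_source_len = cutoff_len - reserved_label_len
        return source_ids[:max_source_len], label_ids[:reserved_label_len]

    # Mirrors expected_2: ([1] * 10, [2] * 5).
    print(split_budget([1] * 20, [2] * 20, cutoff_len=15, reserved_label_len=5))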
tests/unit/flow/model/test_zigzag_ring_flash_attn_func.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#     http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import random
import unittest

import torch
from openmind.utils import get_logger

from tests.utils_for_test import require_npu

logger = get_logger(__name__)


def extract_local(value, rank, world_size, dim=1):
    # Zigzag sharding: split the sequence into 2 * world_size chunks and give
    # each rank one chunk from the front plus its mirror from the back.
    value_chunks = value.chunk(2 * world_size, dim=dim)
    local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim)
    return local_value.contiguous()


class TestZigZagRingAttn(unittest.TestCase):
    @require_npu
    def test_zig_zag_ring_attn(self):
        import torch_npu

        rank = 0
        world_size = 1
        seed = 0
        random.seed(seed)
        torch.manual_seed(seed)
        dtype = torch.bfloat16
        device = torch.device("npu:0")

        batch_size = 1
        seqlen = 3824
        nheads = 5
        d = 128
        dropout_p = 0
        assert d % 8 == 0

        q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
        k = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)

        local_q = extract_local(q, rank, world_size).detach().clone()
        local_k = extract_local(k, rank, world_size).detach().clone()
        local_v = extract_local(v, rank, world_size).detach().clone()
        local_q.requires_grad = True
        local_k.requires_grad = True
        local_v.requires_grad = True

        attn_mask = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
        out, softmax_max, softmax_sum, _, _, _, _ = torch_npu.npu_fusion_attention(
            q,
            k,
            v,
            head_num=q.shape[2],
            input_layout="BSND",
            atten_mask=attn_mask,
            scale=d ** (-0.5),
            pre_tockens=k.shape[1],
            next_tockens=0,
            sparse_mode=3,
            keep_prob=1.0 - dropout_p,
        )

        local_out = extract_local(out, rank, world_size)
        local_softmax_max = extract_local(softmax_max, rank, world_size, dim=2)
        local_softmax_sum = extract_local(softmax_sum, rank, world_size, dim=2)
        # softmax_max shape is [batch_size, nheads, seqlen, 8]
        assert softmax_max.shape == (batch_size, nheads, seqlen, 8)
        assert local_out.shape == (batch_size, seqlen, nheads, d)
        assert local_softmax_max.shape == (batch_size, nheads, seqlen, 8)
        assert local_softmax_sum.shape == (batch_size, nheads, seqlen, 8)
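The extract_local helper above implements the zigzag layout used by ring attention: because the mask is causal, later positions attend over more keys, so pairing an early chunk with its late mirror gives every rank a comparable share of the work. A small CPU-only check of that layout, using the function defined in the file (no NPU required):

    import torch

    seq = torch.arange(8).reshape(1, 8)
    for rank in range(2):
        # world_size=2 -> 4 chunks; rank 0 keeps chunks 0 and 3, rank 1 keeps 1 and 2.
        print(rank, extract_local(seq, rank, world_size=2).tolist())
    # Prints [[0, 1, 6, 7]] for rank 0 and [[2, 3, 4, 5]] for rank 1.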
@@ -172,6 +172,16 @@ def require_torch(test_case):
    return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


def require_npu(test_case):
    """
    Decorator marking a test that requires an NPU. These tests are skipped when no NPU is found.
    """
    from accelerate.utils import is_npu_available

    return unittest.skipUnless(is_npu_available(check_device=True), "test requires NPU")(test_case)


def require_lm_eval(test_case):
    return unittest.skipUnless(is_lmeval_available(), "test requires lm_eval")(test_case)
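Usage mirrors the other require_* helpers in this module; a minimal example (the test body here is hypothetical):

    import unittest

    class TestOnNpu(unittest.TestCase):
        @require_npu
        def test_something_on_npu(self):
            # Skipped automatically unless accelerate reports a reachable NPU.
            self.assertTrue(True)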