@ -61,7 +61,6 @@ def convert_pairwise(examples, datasets_attr: InstructionDatasetAttr):
|
||||
prompt.append({"role": "user", "content": "\n".join(content)})
|
||||
|
||||
if examples[datasets_attr.chosen][i] and examples[datasets_attr.rejected][i]:
|
||||
# response.append([examples[datasets_attr.chosen][i], examples[datasets_attr.rejected][i]])
|
||||
response.append(
|
||||
[
|
||||
{"role": "assistant", "content": examples[datasets_attr.chosen][i]},
|
||||
|
@ -23,14 +23,14 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence, Union, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from openmind.flow.model.model_registry import SUPPORTED_MODELS
|
||||
from openmind.utils.constants import Tokens
|
||||
from openmind.utils import get_logger
|
||||
from openmind.flow.arguments import get_args
|
||||
from openmind.flow.datasets.mm_plugin import BasePlugin, parse_mm_plugin
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# {"qwen": openmind.flow.datasets.template.Template object}
|
||||
|
@ -153,6 +153,11 @@ def zigzag_ring_flash_attn_forward(
|
||||
return out, softmax_max, softmax_sum
|
||||
|
||||
|
||||
class InvalidCausalValueError(Exception):
    """Error raised when zigzag ring attention is invoked with ``causal=False``.

    The string form of the exception is a fixed message; any arguments passed
    to the constructor do not affect it.
    """

    def __str__(self):
        # Always report the same explanation, regardless of constructor args.
        return "zigzag is meaningless for causal=False"
|
||||
|
||||
|
||||
def zigzag_ring_flash_attn_backward(
|
||||
process_group,
|
||||
dout,
|
||||
@ -167,7 +172,9 @@ def zigzag_ring_flash_attn_backward(
|
||||
dropout_p=0,
|
||||
causal=True,
|
||||
):
|
||||
assert causal is True, "zigzag is meaningless for causal=False"
|
||||
if not causal:
|
||||
raise InvalidCausalValueError()
|
||||
|
||||
kv_comm = RingComm(process_group)
|
||||
d_kv_comm = RingComm(process_group)
|
||||
dq, dk, dv = None, None, None
|
||||
|
@ -22,6 +22,7 @@ from openmind.utils.generic import working_or_temp_dir
|
||||
from openmind.utils.hub import PushToHubMixin, om_hub
|
||||
from tests.constants import OPENMIND_HUB_ENDPOINT
|
||||
from tests.functional.hub.testing_constants import TOKEN, is_network_ok, random_repo_name
|
||||
from tests.utils_for_test import slow
|
||||
|
||||
os.environ["OPENMIND_HUB_ENDPOINT"] = OPENMIND_HUB_ENDPOINT
|
||||
|
||||
@ -43,6 +44,7 @@ DEFAULT_FLAGS = os.O_WRONLY | os.O_CREAT
|
||||
DEFAULT_MODES = stat.S_IWUSR | stat.S_IRUSR
|
||||
|
||||
|
||||
@slow
|
||||
@testtools.skipIf(not is_network_ok(), "Can not visit test server, the error code is 418.")
|
||||
class PushToHubMixinTest(TestCase):
|
||||
def setUp(self):
|
||||
@ -111,6 +113,7 @@ class PushToHubMixinTest(TestCase):
|
||||
self.repo_id = repo_id
|
||||
|
||||
|
||||
@slow
|
||||
class HubApiTest(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
@ -19,13 +19,6 @@ import random
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
is_npu_available = True
|
||||
except ImportError:
|
||||
print("Failed to import torch_npu.")
|
||||
is_npu_available = False
|
||||
import torch.distributed as dist
|
||||
|
||||
from openmind.flow.model.context_parallel.zigzag_ring_flash_attn_varlen import (
|
||||
@ -34,6 +27,14 @@ from openmind.flow.model.context_parallel.zigzag_ring_flash_attn_varlen import (
|
||||
get_sub_seq_lens,
|
||||
)
|
||||
|
||||
is_npu_available = False
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
except ImportError:
|
||||
print("Failed to import torch_npu.")
|
||||
else:
|
||||
is_npu_available = True
|
||||
|
||||
|
||||
def extract_softmax_value(softmax_value, cu_seqlens):
|
||||
values = []
|
||||
|
Reference in New Issue
Block a user