!247 [master]codecheck整改

Merge pull request !247 from 张烨槟/master
This commit is contained in:
张烨槟
2025-06-26 11:05:58 +00:00
committed by i-robot
parent 042384994a
commit 5b8d1d4bf2
5 changed files with 21 additions and 11 deletions

View File

@ -61,7 +61,6 @@ def convert_pairwise(examples, datasets_attr: InstructionDatasetAttr):
prompt.append({"role": "user", "content": "\n".join(content)})
if examples[datasets_attr.chosen][i] and examples[datasets_attr.rejected][i]:
# response.append([examples[datasets_attr.chosen][i], examples[datasets_attr.rejected][i]])
response.append(
[
{"role": "assistant", "content": examples[datasets_attr.chosen][i]},

View File

@ -23,14 +23,14 @@ import os
from pathlib import Path
from typing import Dict, List, Sequence, Union, Optional
from transformers import PreTrainedTokenizer
from openmind.flow.model.model_registry import SUPPORTED_MODELS
from openmind.utils.constants import Tokens
from openmind.utils import get_logger
from openmind.flow.arguments import get_args
from openmind.flow.datasets.mm_plugin import BasePlugin, parse_mm_plugin
from transformers import PreTrainedTokenizer
logger = get_logger(__name__)
# {"qwen": openmind.flow.datasets.template.Template object}

View File

@ -153,6 +153,11 @@ def zigzag_ring_flash_attn_forward(
return out, softmax_max, softmax_sum
class InvalidCausalValueError(Exception):
    """Raised when a zigzag ring-attention routine is invoked with causal=False.

    The zigzag sequence partitioning only balances work for causal
    attention, so non-causal calls are rejected with this error.
    """

    # Single fixed message; kept as a class attribute so __str__ stays trivial.
    _MESSAGE = "zigzag is meaningless for causal=False"

    def __str__(self):
        return self._MESSAGE
def zigzag_ring_flash_attn_backward(
process_group,
dout,
@ -167,7 +172,9 @@ def zigzag_ring_flash_attn_backward(
dropout_p=0,
causal=True,
):
assert causal is True, "zigzag is meaningless for causal=False"
if not causal:
raise InvalidCausalValueError()
kv_comm = RingComm(process_group)
d_kv_comm = RingComm(process_group)
dq, dk, dv = None, None, None

View File

@ -22,6 +22,7 @@ from openmind.utils.generic import working_or_temp_dir
from openmind.utils.hub import PushToHubMixin, om_hub
from tests.constants import OPENMIND_HUB_ENDPOINT
from tests.functional.hub.testing_constants import TOKEN, is_network_ok, random_repo_name
from tests.utils_for_test import slow
os.environ["OPENMIND_HUB_ENDPOINT"] = OPENMIND_HUB_ENDPOINT
@ -43,6 +44,7 @@ DEFAULT_FLAGS = os.O_WRONLY | os.O_CREAT
DEFAULT_MODES = stat.S_IWUSR | stat.S_IRUSR
@slow
@testtools.skipIf(not is_network_ok(), "Can not visit test server, the error code is 418.")
class PushToHubMixinTest(TestCase):
def setUp(self):
@ -111,6 +113,7 @@ class PushToHubMixinTest(TestCase):
self.repo_id = repo_id
@slow
class HubApiTest(TestCase):
def setUp(self):
super().setUp()

View File

@ -19,13 +19,6 @@ import random
import torch
try:
import torch_npu # noqa: F401
is_npu_available = True
except ImportError:
print("Failed to import torch_npu.")
is_npu_available = False
import torch.distributed as dist
from openmind.flow.model.context_parallel.zigzag_ring_flash_attn_varlen import (
@ -34,6 +27,14 @@ from openmind.flow.model.context_parallel.zigzag_ring_flash_attn_varlen import (
get_sub_seq_lens,
)
# Optional Ascend NPU support probe. Default to unavailable and flip the
# flag only when the import succeeds.
is_npu_available = False
try:
    # NOTE(review): imported for its side effects rather than its name
    # (hence the F401 suppression) — presumably importing torch_npu
    # registers the NPU backend with PyTorch; confirm against torch_npu docs.
    import torch_npu  # noqa: F401
except ImportError:
    # Best-effort notice only; execution continues without NPU support.
    print("Failed to import torch_npu.")
else:
    # try/except/else keeps the success-path assignment out of the guarded
    # region, so an unrelated error here cannot be mistaken for ImportError.
    is_npu_available = True
def extract_softmax_value(softmax_value, cu_seqlens):
values = []