Compare commits

...

1 Commits

Author SHA1 Message Date
32edb66a42 修复 rm 性能 2025-05-21 12:11:39 +08:00
5 changed files with 317 additions and 29 deletions

View File

@ -10,7 +10,7 @@
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
import argparse
import random
from dataclasses import dataclass
from typing import List, Any, Optional, Union, Dict
@ -51,6 +51,7 @@ class RewardDataCollatorWithPadding:
"""
tokenizer: PreTrainedTokenizerBase
args: argparse.Namespace
padding: Union[bool, str] = True
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
@ -61,10 +62,12 @@ class RewardDataCollatorWithPadding:
margin = []
# check if we have a margin. If we do, we need to batch it as well
has_margin = "margin" in features[0]
max_length = 0
if self.args.max_length:
max_length = self.args.max_length
else:
max_length = 1024
for feature in features:
# check if the keys are named as expected
max_length = max(max_length, len(feature["input_ids_chosen"]), len(feature["input_ids_rejected"]))
keys_exist = (
"input_ids_chosen" in feature
and "input_ids_rejected" in feature
@ -132,7 +135,7 @@ def run_rm(
train_args = args.reward_args
train_args.remove_unused_columns = False
data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer)
data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer, args=args)
trainer = reward_trainer.RewardTrainer(
model=model,
args=train_args,

View File

@ -22,6 +22,7 @@ from transformers.models.mistral import modeling_mistral
from openmind.integrations.transformers.npu_fused_ops import attenions, rms_norm, rope, swiglu
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
from openmind.utils.version import check_package_version
@dataclasses.dataclass
@ -40,7 +41,7 @@ def _builtin_patch_flash_attention(RAW_ATTENTION_CLASSES: Dict, NEW_ATTENTION_CL
RAW_ATTENTION_CLASSES.update({k: NEW_ATTENTION_CLASS for k in RAW_ATTENTION_CLASSES})
def __builtin_patch_flash_attention_v2(config):
def __builtin_patch_sdpa(config):
setattr(config, "_attn_implementation", "sdpa")
@ -71,7 +72,7 @@ def _apply_fused_kernel_base(module: ModuleType, **kwargs):
_builtin_patch_flash_attention(getattr(module, attention_classes_attr), attention)
elif torch.__version__ >= "2.6.0":
config = kwargs.get("config")
__builtin_patch_flash_attention_v2(config)
__builtin_patch_sdpa(config)
else:
pass
else:
@ -99,6 +100,12 @@ def apply_fused_kernel_qwen2(**kwargs):
_apply_fused_kernel_base(modeling_qwen2, attention=attenions.qwen2.Qwen2NPUAttention, **kwargs)
def apply_fused_kernel_qwen3(**kwargs):
    """
    Apply the npu fused kernels to the transformers Qwen3 modeling module.

    Nothing is done unless both `transformers>=4.51.1` and `torch>=2.6.0` are
    installed. The Qwen3 modeling module is imported lazily, only after the
    version check has passed.

    Args:
        **kwargs: forwarded unchanged to `_apply_fused_kernel_base`.
    """
    versions_ok = check_package_version("transformers>=4.51.1") and check_package_version("torch>=2.6.0")
    if not versions_ok:
        return

    from transformers.models.qwen3 import modeling_qwen3

    # No replacement attention class is supplied for Qwen3 (attention=None).
    _apply_fused_kernel_base(modeling_qwen3, attention=None, **kwargs)
def apply_fused_kernel_llama(**kwargs):
_apply_fused_kernel_base(modeling_llama, attention=attenions.llama.LlamaNpuFusionAttention, **kwargs)

View File

@ -62,15 +62,11 @@ def check_use_fused_kernel(inner=False) -> bool:
return False
# installed version of transformers and torch is not compatible for npu fused options
try:
if torch.__version__ == "2.1.0":
version.require_version("transformers<=4.47.1")
version.require_version("transformers>=4.45.0")
elif torch.__version__ >= "2.6.0":
version.require_version("transformers>=4.51.1")
else:
return False
except ImportError:
if torch.__version__ == "2.1.0" and version.check_package_version("transformers<=4.47.1, >=4.45.0"):
return True
elif torch.__version__ >= "2.6.0" and version.check_package_version("transformers>=4.51.1"):
return True
else:
logger.warning_rank0(
f"RuntimeWarning: The npu fused options is not available under the transformers v{transformers.__version__} "
f"and the torch v{torch.__version__}. To use npu fused options, if torch version >= 2.6.0, the version of "
@ -78,8 +74,7 @@ def check_use_fused_kernel(inner=False) -> bool:
f"required >= v4.45.0, and <= 4.47.1; In other cases, the npu fused options will not be available. "
)
return False
# check pass
return True
@lru_cache
@ -126,6 +121,7 @@ def apply_fused_kernel(**kwargs):
`use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
"""
_apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen2, **kwargs)
_apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen3, **kwargs)
_apply_fused_kernel_generic(kernel.apply_fused_kernel_llama, **kwargs)
_apply_fused_kernel_generic(kernel.apply_fused_kernel_mistral, **kwargs)
_apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm2, **kwargs)
@ -145,6 +141,18 @@ def apply_fused_kernel_to_qwen2(**kwargs):
_apply_log(model_type="qwen2", **kwargs)
def apply_fused_kernel_to_qwen3(**kwargs):
    """
    Apply npu fused operators for Qwen3 series models. When this function is called, all supported
    fusion operators are enabled by default. You can set the following parameters to disable the
    specified fused operator:
    `use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
    `use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
    """
    _apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen3, **kwargs)
    _apply_log(model_type="qwen3", **kwargs)
def apply_fused_kernel_to_internlm2(**kwargs):
"""
Apply npu fused operators for Internlm2 series models, when call this function, all supported
@ -195,6 +203,7 @@ def apply_fused_kernel_to_mistral(**kwargs):
SUPPORTED_FUSED_MODELS = {
"Qwen2ForCausalLM": apply_fused_kernel_to_qwen2,
"Qwen3ForCausalLM": apply_fused_kernel_to_qwen3,
"LlamaForCausalLM": apply_fused_kernel_to_llama,
"MistralForCausalLM": apply_fused_kernel_to_mistral,
"InternLM2ForCausalLM": apply_fused_kernel_to_internlm2,
@ -204,6 +213,6 @@ SUPPORTED_FUSED_MODELS = {
def map_fused_kernel_to_model(architecture, **kwargs):
if architecture not in SUPPORTED_FUSED_MODELS:
logger.warning_rank0(f"Unsupported fused model architecture: {architecture}")
logger.warning_rank0(f"Unadapted model architecture for npu fused options: {architecture}, this model will use the default options.")
return
SUPPORTED_FUSED_MODELS.get(architecture)(inner=True, **kwargs)

View File

@ -24,23 +24,26 @@ from typing import Optional
from packaging import version
# Operator mapping for version comparison
_ops_map = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}


def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
    """
    Check a single `got_ver <op> want_ver` constraint and raise if it does not hold.

    Args:
        op: comparison operator string; must be a key of `_ops_map`.
        got_ver: installed version string, or None when it could not be determined.
        want_ver: required version string, or None.
        requirement: the full requirement string, used only in error messages.
        pkg: package name, used only in error messages.
        hint: text appended to the ImportError message.

    Raises:
        ValueError: if `op` is not supported, or if either version is None.
        ImportError: if the installed version does not satisfy the constraint.
    """
    if op not in _ops_map:
        # The message must reference `_ops_map`; the former local `ops` dict no longer exists.
        raise ValueError(f"{requirement}: need one of {list(_ops_map.keys())}, but got {op}")
    if got_ver is None or want_ver is None:
        raise ValueError(
            f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider"
            f" reinstalling {pkg}."
        )
    if not _ops_map[op](version.parse(got_ver), version.parse(want_ver)):
        raise ImportError(
            f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
        )
@ -108,3 +111,148 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
if want_ver is not None:
for op, want_ver in wanted.items():
_compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
def _parse_requirement_string(requirement_string: str):
    """
    Parse a requirement string into a package name and a list of conditions.

    Args:
        requirement_string (str): The requirement string, e.g.,
            "transformers>=4.45.0,<4.50.0" or "numpy" or "python==3.9.1".

    Returns:
        tuple: (pkg_name, list_of_conditions) or (None, None) on parse failure.
            list_of_conditions is a list of (operator_str, version_str) tuples.
            Returns (pkg_name, []) if no version is specified (existence check).
    """
    cleaned_requirement = requirement_string.strip()
    if not cleaned_requirement:
        return None, None

    # Case 1: bare package name (letters, digits, '_', '.', '-') -> existence check only.
    if re.fullmatch(r"[a-zA-Z0-9_.\-]+", cleaned_requirement):
        return cleaned_requirement, []

    # Case 2: package name followed by a specifier list that must start with an operator character.
    match_pkg_and_spec = re.match(r"^\s*([a-zA-Z0-9_.\-]+)\s*([!=<>]{1,2}.*)\s*$", cleaned_requirement)
    if not match_pkg_and_spec:
        return None, None
    pkg_name = match_pkg_and_spec.group(1)
    version_specs_str = match_pkg_and_spec.group(2).strip()

    conditions = []
    # Conditions are comma separated (e.g. ">=1.0, <2.0"); empty segments are ignored.
    for part in version_specs_str.split(","):
        part = part.strip()
        if not part:
            continue
        # Group 1: operator, group 2: version string (spaces around both are tolerated).
        match_condition = re.match(r"^\s*([!=<>]{1,2})\s*(.+?)\s*$", part)
        if not match_condition:
            return None, None  # Malformed condition, e.g. "<" or "invalid<2.0"
        op_str = match_condition.group(1)
        ver_str = match_condition.group(2).strip()
        if op_str not in _ops_map:
            return None, None  # Unsupported operator, e.g. ">>" or "="
        conditions.append((op_str, ver_str))

    # The specifier list begins with an operator character, so at least one condition was parsed.
    return pkg_name, conditions
def check_package_version(requirement_string: str) -> bool:
    """
    Check whether an installed package meets the specified version requirements.

    Args:
        requirement_string (str): A requirement string in a pip-like format,
            e.g., "transformers<=4.51.3, >=4.48.0, !=4.51.2", "numpy==1.20.0",
            "requests", "python>=3.8".

    Returns:
        bool: True if the installed version meets all requirements, False otherwise
            (including if package is not found, requirement is malformed,
            or version comparison fails).
    """
    pkg_name, conditions = _parse_requirement_string(requirement_string)
    if pkg_name is None:
        # Malformed requirement string.
        return False

    if pkg_name.lower() == "python":
        # "python" is not a distribution; compare against the running interpreter instead.
        installed_version_str = ".".join(map(str, sys.version_info[:3]))
    else:
        try:
            installed_version_str = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            return False

    # No conditions means a pure existence check, which has succeeded at this point.
    if not conditions:
        return True

    # A missing version cannot be compared; report "not satisfied" rather than
    # letting version.parse raise, mirroring the None check in _compare_versions.
    if installed_version_str is None:
        return False

    try:
        parsed_installed_version = version.parse(installed_version_str)
    except version.InvalidVersion:
        # The installed version string cannot be compared reliably.
        return False

    for op_str, required_version_str in conditions:
        try:
            parsed_required_version = version.parse(required_version_str)
        except version.InvalidVersion:
            # The requirement itself carries an invalid version.
            return False
        if not _ops_map[op_str](parsed_installed_version, parsed_required_version):
            return False

    return True

View File

@ -13,8 +13,9 @@
import unittest
import importlib.metadata
import sys
from openmind.utils.version import require_version
from openmind.utils.version import require_version, check_package_version
class RequireVersionTest(unittest.TestCase):
@ -45,3 +46,123 @@ class RequireVersionTest(unittest.TestCase):
with self.assertRaises(ValueError) as cm:
require_version("numpy?0.0.1")
self.assertIn("requirement needs to be in the pip package format", str(cm.exception))
class MockSys:
    """Minimal stand-in for the `sys` module that only carries `version_info`."""

    def __init__(self, version_info):
        # Stored as given; callers read it back through `.version_info`.
        self.version_info = version_info
class MockMetadata:
    """In-memory replacement for `importlib.metadata` version lookups."""

    def __init__(self):
        # Maps package name -> version string.
        self.installed_packages = {}

    def add_package(self, name, version):
        """Register `name` as installed with the given `version`."""
        self.installed_packages[name] = version

    def version(self, pkg_name):
        """Return the registered version, or raise PackageNotFoundError if unknown."""
        if pkg_name not in self.installed_packages:
            raise importlib.metadata.PackageNotFoundError(f"No package named '{pkg_name}'")
        return self.installed_packages[pkg_name]
# Store original functions/objects
class CheckVersionTest(unittest.TestCase):
    """Table-driven tests for `check_package_version` against mocked package metadata."""

    def setUp(self) -> None:
        self.original_importlib_metadata_version = importlib.metadata.version
        self.original_sys_version_info = sys.version_info

        # Monkey patch the metadata lookup; addCleanup guarantees it is undone
        # even if the test body raises.
        self.mock_meta = MockMetadata()
        importlib.metadata.version = self.mock_meta.version
        self.addCleanup(setattr, importlib.metadata, "version", self.original_importlib_metadata_version)

    def test_version_check(self):
        # Each case: (requirement, expected result, (pkg, installed version) or None, python version or None)
        test_cases = [
            # Basic comparisons
            ("transformers==0.5.0", True, ("transformers", "0.5.0"), None),
            ("transformers==0.5.0", False, ("transformers", "0.6.0"), None),
            ("transformers>=0.5.0", True, ("transformers", "0.5.0"), None),
            ("transformers>=0.5.0", True, ("transformers", "0.6.0"), None),
            ("transformers>=0.5.0", False, ("transformers", "0.4.0"), None),
            ("transformers>=4.48.0", False, ("transformers", "4.6.0"), None),
            ("transformers>0.5.0", False, ("transformers", "0.5.0"), None),
            ("transformers<=4.51.3, >=4.48.0, !=4.51.2", True, ("transformers", "4.51.1"), None),
            ("transformers<=4.47.1", True, ("transformers", "4.47.1"), None),
            ("transformers<=4.47.1", False, ("transformers", "4.48.0"), None),
            ("transformers<0.5.0", True, ("transformers", "0.4.0"), None),
            ("transformers<0.5.0", False, ("transformers", "0.5.0"), None),
            ("transformers!=0.5.0", True, ("transformers", "0.6.0"), None),
            ("transformers!=0.5.0", False, ("transformers", "0.5.0"), None),
            # Range comparisons (multiple conditions)
            ("requests>=2.0.0,<3.0.0", True, ("requests", "2.5.1"), None),
            ("requests>=2.0.0,<3.0.0", False, ("requests", "3.0.0"), None),
            ("requests>=2.0.0,<3.0.0", False, ("requests", "1.9.0"), None),
            ("my-pkg>1.0,<=2.0.0rc1", True, ("my-pkg", "2.0.0rc1"), None),
            ("my-pkg>1.0,<=2.0.0rc1", True, ("my-pkg", "1.5"), None),
            ("my-pkg>1.0,<=2.0.0rc1", False, ("my-pkg", "1.0"), None),  # >1.0 fails for 1.0
            ("my-pkg>1.0,<=2.0.0rc1", False, ("my-pkg", "2.0.0"), None),  # >2.0.0rc1
            # Existence check
            ("numpy", True, ("numpy", "1.20.0"), None),
            ("nonexistent_pkg", False, None, None),  # Package not found
            # Python version check
            ("python>=3.8.0", True, None, (3, 8, 5)),
            ("python>=3.8.0", True, None, (3, 9, 1)),
            ("python>=3.8.0", False, None, (3, 7, 0)),
            ("python==3.9.1", True, None, (3, 9, 1)),
            ("python==3.9.1", False, None, (3, 9, 0)),
            ("python<3.10,>=3.8", True, None, (3, 9, 5)),
            ("python<3.10,>=3.8", False, None, (3, 7, 5)),
            ("python<3.10,>=3.8", False, None, (3, 10, 0)),
            # Semantic versioning nuances (e.g., 4.5 vs 4.45.0)
            ("torch>=4.45.0", False, ("torch", "4.5.0"), None),  # 4.5.0 is less than 4.45.0
            ("torch<4.45.0", True, ("torch", "4.5.0"), None),
            ("torch>=4.5.0", True, ("torch", "4.45.0"), None),  # 4.45.0 is greater than 4.5.0
            ("torch==4.5.0", True, ("torch", "4.5"), None),  # 4.5.0 and 4.5 are equivalent by packaging.version
            ("torch==4.5", True, ("torch", "4.5.0"), None),
            # Malformed requirement strings
            ("some_pkg>>1.0", False, None, None),  # Invalid operator
            ("another_pkg>=1.0,<", False, None, None),  # Incomplete condition
            ("yetanother_pkg>=1.0,invalid<2.0", False, None, None),  # Malformed second part
            ("", False, None, None),  # Empty string
            ("pkg,=1.0", False, None, None),  # Malformed
            ("pkg>=1..0", False, None, None),  # Invalid version in requirement
            # Package with invalid installed version
            ("badverpkg>=1.0", False, ("badverpkg", "1.0-invalid-version"), None),
            # Spaces and complex names
            (" my.pkg-name_with_stuff >= 1.0.0a1 , <1.1.0.post2 ", True, ("my.pkg-name_with_stuff", "1.0.1"), None),
            (" my.pkg-name_with_stuff >= 1.0.0a1 , <1.1.0.post2 ", False, ("my.pkg-name_with_stuff", "1.2.0"), None),
        ]

        for req_str, expected, pkg_setup, py_setup in test_cases:
            # Reset mocks for each case
            self.mock_meta.installed_packages.clear()
            if pkg_setup:
                self.mock_meta.add_package(pkg_setup[0], pkg_setup[1])

            # Cases without an explicit interpreter version run against a default of 3.9.0.
            sys.version_info = MockSys(py_setup if py_setup else (3, 9, 0)).version_info
            try:
                result = check_package_version(req_str)
            finally:
                # Restore immediately so the fake interpreter version never leaks
                # into unittest's own failure reporting or into later tests.
                sys.version_info = self.original_sys_version_info

            # subTest reports every failing case individually instead of stopping at the first.
            with self.subTest(requirement=req_str, pkg=pkg_setup, py=py_setup):
                self.assertEqual(result, expected)