Compare commits
1 Commits
master
...
master-520
Author | SHA1 | Date | |
---|---|---|---|
32edb66a42 |
@ -10,7 +10,7 @@
|
||||
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
# See the Mulan PSL v2 for more details.
|
||||
|
||||
import argparse
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any, Optional, Union, Dict
|
||||
@ -51,6 +51,7 @@ class RewardDataCollatorWithPadding:
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
args: argparse.Namespace
|
||||
padding: Union[bool, str] = True
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
return_tensors: str = "pt"
|
||||
@ -61,10 +62,12 @@ class RewardDataCollatorWithPadding:
|
||||
margin = []
|
||||
# check if we have a margin. If we do, we need to batch it as well
|
||||
has_margin = "margin" in features[0]
|
||||
max_length = 0
|
||||
if self.args.max_length:
|
||||
max_length = self.args.max_length
|
||||
else:
|
||||
max_length = 1024
|
||||
for feature in features:
|
||||
# check if the keys are named as expected
|
||||
max_length = max(max_length, len(feature["input_ids_chosen"]), len(feature["input_ids_rejected"]))
|
||||
keys_exist = (
|
||||
"input_ids_chosen" in feature
|
||||
and "input_ids_rejected" in feature
|
||||
@ -132,7 +135,7 @@ def run_rm(
|
||||
train_args = args.reward_args
|
||||
train_args.remove_unused_columns = False
|
||||
|
||||
data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer)
|
||||
data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer, args=args)
|
||||
trainer = reward_trainer.RewardTrainer(
|
||||
model=model,
|
||||
args=train_args,
|
||||
|
@ -22,6 +22,7 @@ from transformers.models.mistral import modeling_mistral
|
||||
|
||||
from openmind.integrations.transformers.npu_fused_ops import attenions, rms_norm, rope, swiglu
|
||||
from openmind.integrations.transformers.npu_fused_ops import dynamic_module_utils
|
||||
from openmind.utils.version import check_package_version
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -40,7 +41,7 @@ def _builtin_patch_flash_attention(RAW_ATTENTION_CLASSES: Dict, NEW_ATTENTION_CL
|
||||
RAW_ATTENTION_CLASSES.update({k: NEW_ATTENTION_CLASS for k in RAW_ATTENTION_CLASSES})
|
||||
|
||||
|
||||
def __builtin_patch_flash_attention_v2(config):
|
||||
def __builtin_patch_sdpa(config):
|
||||
setattr(config, "_attn_implementation", "sdpa")
|
||||
|
||||
|
||||
@ -71,7 +72,7 @@ def _apply_fused_kernel_base(module: ModuleType, **kwargs):
|
||||
_builtin_patch_flash_attention(getattr(module, attention_classes_attr), attention)
|
||||
elif torch.__version__ >= "2.6.0":
|
||||
config = kwargs.get("config")
|
||||
__builtin_patch_flash_attention_v2(config)
|
||||
__builtin_patch_sdpa(config)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
@ -99,6 +100,12 @@ def apply_fused_kernel_qwen2(**kwargs):
|
||||
_apply_fused_kernel_base(modeling_qwen2, attention=attenions.qwen2.Qwen2NPUAttention, **kwargs)
|
||||
|
||||
|
||||
def apply_fused_kernel_qwen3(**kwargs):
|
||||
if check_package_version("transformers>=4.51.1") and check_package_version("torch>=2.6.0"):
|
||||
from transformers.models.qwen3 import modeling_qwen3
|
||||
_apply_fused_kernel_base(modeling_qwen3, attention=None, **kwargs)
|
||||
|
||||
|
||||
def apply_fused_kernel_llama(**kwargs):
|
||||
_apply_fused_kernel_base(modeling_llama, attention=attenions.llama.LlamaNpuFusionAttention, **kwargs)
|
||||
|
||||
|
@ -62,15 +62,11 @@ def check_use_fused_kernel(inner=False) -> bool:
|
||||
return False
|
||||
|
||||
# installed version of transformers and torch is not compatible for npu fused options
|
||||
try:
|
||||
if torch.__version__ == "2.1.0":
|
||||
version.require_version("transformers<=4.47.1")
|
||||
version.require_version("transformers>=4.45.0")
|
||||
elif torch.__version__ >= "2.6.0":
|
||||
version.require_version("transformers>=4.51.1")
|
||||
else:
|
||||
return False
|
||||
except ImportError:
|
||||
if torch.__version__ == "2.1.0" and version.check_package_version("transformers<=4.47.1, >=4.45.0"):
|
||||
return True
|
||||
elif torch.__version__ >= "2.6.0" and version.check_package_version("transformers>=4.51.1"):
|
||||
return True
|
||||
else:
|
||||
logger.warning_rank0(
|
||||
f"RuntimeWarning: The npu fused options is not available under the transformers v{transformers.__version__} "
|
||||
f"and the torch v{torch.__version__}. To use npu fused options, if torch version >= 2.6.0, the version of "
|
||||
@ -78,8 +74,7 @@ def check_use_fused_kernel(inner=False) -> bool:
|
||||
f"required >= v4.45.0, and <= 4.47.1; In other cases, the npu fused options will not be available. "
|
||||
)
|
||||
return False
|
||||
# check pass
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@lru_cache
|
||||
@ -126,6 +121,7 @@ def apply_fused_kernel(**kwargs):
|
||||
`use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
|
||||
"""
|
||||
_apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen2, **kwargs)
|
||||
_apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen3, **kwargs)
|
||||
_apply_fused_kernel_generic(kernel.apply_fused_kernel_llama, **kwargs)
|
||||
_apply_fused_kernel_generic(kernel.apply_fused_kernel_mistral, **kwargs)
|
||||
_apply_fused_kernel_generic(kernel.apply_fused_kernel_internlm2, **kwargs)
|
||||
@ -145,6 +141,18 @@ def apply_fused_kernel_to_qwen2(**kwargs):
|
||||
_apply_log(model_type="qwen2", **kwargs)
|
||||
|
||||
|
||||
def apply_fused_kernel_to_qwen3(**kwargs):
|
||||
"""
|
||||
Apply npu fused operators for Qwen2 series models, when call this function, all supported
|
||||
fusion operators will be enabled by default. You can set the following parameters to disable the
|
||||
specified fused operator:
|
||||
`use_npu_fusion_attention: bool = False`, default is True, set it to `False` to disable npu fusion attention.
|
||||
`use_fused_rms_norm: bool = False`, default is True, set it to `False` to disable npu RMSNorm.
|
||||
"""
|
||||
_apply_fused_kernel_generic(kernel.apply_fused_kernel_qwen3, **kwargs)
|
||||
_apply_log(model_type="qwen3", **kwargs)
|
||||
|
||||
|
||||
def apply_fused_kernel_to_internlm2(**kwargs):
|
||||
"""
|
||||
Apply npu fused operators for Internlm2 series models, when call this function, all supported
|
||||
@ -195,6 +203,7 @@ def apply_fused_kernel_to_mistral(**kwargs):
|
||||
|
||||
SUPPORTED_FUSED_MODELS = {
|
||||
"Qwen2ForCausalLM": apply_fused_kernel_to_qwen2,
|
||||
"Qwen3ForCausalLM": apply_fused_kernel_to_qwen3,
|
||||
"LlamaForCausalLM": apply_fused_kernel_to_llama,
|
||||
"MistralForCausalLM": apply_fused_kernel_to_mistral,
|
||||
"InternLM2ForCausalLM": apply_fused_kernel_to_internlm2,
|
||||
@ -204,6 +213,6 @@ SUPPORTED_FUSED_MODELS = {
|
||||
|
||||
def map_fused_kernel_to_model(architecture, **kwargs):
|
||||
if architecture not in SUPPORTED_FUSED_MODELS:
|
||||
logger.warning_rank0(f"Unsupported fused model architecture: {architecture}")
|
||||
logger.warning_rank0(f"Unadapted model architecture for npu fused options: {architecture}, this model will use the default options.")
|
||||
return
|
||||
SUPPORTED_FUSED_MODELS.get(architecture)(inner=True, **kwargs)
|
||||
|
@ -24,23 +24,26 @@ from typing import Optional
|
||||
from packaging import version
|
||||
|
||||
|
||||
# Operator mapping for version comparison
|
||||
_ops_map = {
|
||||
"<": operator.lt,
|
||||
"<=": operator.le,
|
||||
"==": operator.eq,
|
||||
"!=": operator.ne,
|
||||
">=": operator.ge,
|
||||
">": operator.gt,
|
||||
}
|
||||
|
||||
|
||||
def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
|
||||
ops = {
|
||||
"<": operator.lt,
|
||||
"<=": operator.le,
|
||||
"==": operator.eq,
|
||||
"!=": operator.ne,
|
||||
">=": operator.ge,
|
||||
">": operator.gt,
|
||||
}
|
||||
if op not in ops:
|
||||
if op not in _ops_map:
|
||||
raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
|
||||
if got_ver is None or want_ver is None:
|
||||
raise ValueError(
|
||||
f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider"
|
||||
f" reinstalling {pkg}."
|
||||
)
|
||||
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
||||
if not _ops_map[op](version.parse(got_ver), version.parse(want_ver)):
|
||||
raise ImportError(
|
||||
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
|
||||
)
|
||||
@ -108,3 +111,148 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
||||
if want_ver is not None:
|
||||
for op, want_ver in wanted.items():
|
||||
_compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
|
||||
|
||||
|
||||
def _parse_requirement_string(requirement_string: str):
|
||||
"""
|
||||
Parses a requirement string into a package name and a list of conditions.
|
||||
|
||||
Args:
|
||||
requirement_string (str): The requirement string, e.g.,
|
||||
"transformers>=4.45.0,<4.50.0" or "numpy" or "python==3.9.1".
|
||||
|
||||
Returns:
|
||||
tuple: (pkg_name, list_of_conditions) or (None, None) on parse failure.
|
||||
list_of_conditions is a list of (operator_str, version_str) tuples.
|
||||
Returns (pkg_name, []) if no version is specified (existence check).
|
||||
"""
|
||||
cleaned_requirement = requirement_string.strip()
|
||||
if not cleaned_requirement:
|
||||
return None, None # Empty requirement string
|
||||
|
||||
# Case 1: Only package name (no version specifiers)
|
||||
# Matches common package name characters: alphanumeric, underscore, dot, hyphen.
|
||||
simple_pkg_match = re.fullmatch(r"^[a-zA-Z0-9_.\-]+$", cleaned_requirement)
|
||||
if simple_pkg_match:
|
||||
return cleaned_requirement, []
|
||||
|
||||
# Case 2: Package name with version specifiers
|
||||
# Regex to separate package name from the rest of the version specifiers.
|
||||
# Group 1: package_name (e.g., "transformers")
|
||||
# Group 2: the entire version specification string (e.g., ">=4.45.0,<4.50.0" or " ==3.9.1 ")
|
||||
# Package names can contain letters, numbers, '.', '_', '-'.
|
||||
# Version specifiers must start with an operator.
|
||||
match_pkg_and_spec = re.match(r"^\s*([a-zA-Z0-9_.\-]+)\s*([!=<>]{1,2}.*)\s*$", cleaned_requirement)
|
||||
if not match_pkg_and_spec:
|
||||
# print(f"Debug: Failed to parse requirement structure: {cleaned_requirement}")
|
||||
return None, None
|
||||
|
||||
pkg_name = match_pkg_and_spec.group(1)
|
||||
version_specs_str = match_pkg_and_spec.group(2).strip()
|
||||
|
||||
if not pkg_name: # Should not happen with the regex structure
|
||||
return None, None
|
||||
|
||||
conditions = []
|
||||
# Split by comma for multiple conditions (e.g., ">=1.0, <2.0")
|
||||
spec_parts = [s.strip() for s in version_specs_str.split(',') if s.strip()]
|
||||
|
||||
if not spec_parts and version_specs_str: # e.g. "pkg," or "pkg==,"
|
||||
# print(f"Debug: Version spec string '{version_specs_str}' yielded no valid parts for '{pkg_name}'")
|
||||
return None, None
|
||||
|
||||
for part in spec_parts:
|
||||
# Match operator and version for each part
|
||||
# Group 1: operator (e.g., ">=", "<")
|
||||
# Group 2: version_string (e.g., "4.45.0")
|
||||
# Allows spaces around operator and version, which are then stripped.
|
||||
# Version string is matched non-greedily to handle various characters.
|
||||
match_condition = re.match(r"^\s*([!=<>]{1,2})\s*(.+?)\s*$", part)
|
||||
if not match_condition:
|
||||
# print(f"Debug: Failed to parse condition part '{part}' from '{version_specs_str}'")
|
||||
return None, None # Malformed condition part
|
||||
|
||||
op_str = match_condition.group(1).strip()
|
||||
ver_str = match_condition.group(2).strip()
|
||||
|
||||
if op_str not in _ops_map:
|
||||
# print(f"Debug: Invalid operator '{op_str}' in condition '{part}'")
|
||||
return None, None # Invalid operator
|
||||
if not ver_str:
|
||||
# print(f"Debug: Empty version string for operator '{op_str}' in condition '{part}'")
|
||||
return None, None # Empty version string for an operator
|
||||
|
||||
conditions.append((op_str, ver_str))
|
||||
|
||||
# If version_specs_str was present but no conditions were parsed (e.g. "pkg >=")
|
||||
if not conditions and version_specs_str:
|
||||
# print(f"Debug: Version spec string '{version_specs_str}' yielded no valid conditions for '{pkg_name}'")
|
||||
return None, None
|
||||
|
||||
return pkg_name, conditions
|
||||
|
||||
|
||||
def check_package_version(requirement_string: str) -> bool:
|
||||
"""
|
||||
Checks if an installed package meets the specified version requirements.
|
||||
|
||||
Args:
|
||||
requirement_string (str): A requirement string in a pip-like format,
|
||||
e.g., "transformers<=4.51.3, >=4.48.0, !=4.51.2", "numpy==1.20.0",
|
||||
"requests", "python>=3.8".
|
||||
|
||||
Returns:
|
||||
bool: True if the installed version meets all requirements, False otherwise
|
||||
(including if package is not found, requirement is malformed,
|
||||
or version comparison fails).
|
||||
"""
|
||||
pkg_name, conditions = _parse_requirement_string(requirement_string)
|
||||
|
||||
if pkg_name is None:
|
||||
# Malformed requirement string or unhandled parsing case.
|
||||
# print(f"Warning: Could not parse requirement string: '{requirement_string}'")
|
||||
return False
|
||||
|
||||
installed_version_str: str
|
||||
# Special case for Python version
|
||||
if pkg_name.lower() == "python":
|
||||
installed_version_str = ".".join(map(str, sys.version_info[:3]))
|
||||
else:
|
||||
# Get installed package version
|
||||
try:
|
||||
installed_version_str = importlib.metadata.version(pkg_name)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
# Package not found. If there were any conditions or even just an
|
||||
# existence check (empty conditions list), this is a failure.
|
||||
return False
|
||||
|
||||
# If there are no conditions, it's an existence check.
|
||||
# If we reached here, the package (or Python) exists.
|
||||
if not conditions:
|
||||
return True
|
||||
|
||||
# Parse the installed version string
|
||||
try:
|
||||
parsed_installed_version = version.parse(installed_version_str)
|
||||
except version.InvalidVersion:
|
||||
# Installed version string is invalid, cannot reliably compare.
|
||||
# print(f"Warning: Installed version '{installed_version_str}' for '{pkg_name}' is invalid.")
|
||||
return False
|
||||
|
||||
# Check all conditions
|
||||
for op_str, required_version_str in conditions:
|
||||
try:
|
||||
parsed_required_version = version.parse(required_version_str)
|
||||
except version.InvalidVersion:
|
||||
# Requirement's version string is invalid.
|
||||
# print(f"Warning: Required version '{required_version_str}' in '{requirement_string}' is invalid.")
|
||||
return False # Requirement itself is flawed
|
||||
|
||||
comparison_func = _ops_map[op_str]
|
||||
|
||||
# Perform the comparison
|
||||
if not comparison_func(parsed_installed_version, parsed_required_version):
|
||||
return False # This specific condition is not met
|
||||
|
||||
# All conditions were met
|
||||
return True
|
@ -13,8 +13,9 @@
|
||||
|
||||
import unittest
|
||||
import importlib.metadata
|
||||
import sys
|
||||
|
||||
from openmind.utils.version import require_version
|
||||
from openmind.utils.version import require_version, check_package_version
|
||||
|
||||
|
||||
class RequireVersionTest(unittest.TestCase):
|
||||
@ -45,3 +46,123 @@ class RequireVersionTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
require_version("numpy?0.0.1")
|
||||
self.assertIn("requirement needs to be in the pip package format", str(cm.exception))
|
||||
|
||||
|
||||
class MockSys:
|
||||
def __init__(self, version_info):
|
||||
self.version_info = version_info
|
||||
|
||||
|
||||
class MockMetadata:
|
||||
def __init__(self):
|
||||
self.installed_packages = {}
|
||||
|
||||
def version(self, pkg_name):
|
||||
if pkg_name in self.installed_packages:
|
||||
return self.installed_packages[pkg_name]
|
||||
raise importlib.metadata.PackageNotFoundError(f"No package named '{pkg_name}'")
|
||||
|
||||
def add_package(self, name, version):
|
||||
self.installed_packages[name] = version
|
||||
|
||||
# Store original functions/objects
|
||||
class CheckVersionTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.original_importlib_metadata_version = importlib.metadata.version
|
||||
self.original_sys_version_info = sys.version_info
|
||||
|
||||
# Setup mock environment
|
||||
self.mock_meta = MockMetadata()
|
||||
importlib.metadata.version = self.mock_meta.version # Monkey patch
|
||||
|
||||
def test_version_check(self):
|
||||
test_cases = [
|
||||
# Basic comparisons
|
||||
("transformers==0.5.0", True, ("transformers", "0.5.0"), None),
|
||||
("transformers==0.5.0", False, ("transformers", "0.6.0"), None),
|
||||
("transformers>=0.5.0", True, ("transformers", "0.5.0"), None),
|
||||
("transformers>=0.5.0", True, ("transformers", "0.6.0"), None),
|
||||
("transformers>=0.5.0", False, ("transformers", "0.4.0"), None),
|
||||
("transformers>=4.48.0", False, ("transformers", "4.6.0"), None),
|
||||
("transformers>0.5.0", False, ("transformers", "0.5.0"), None),
|
||||
("transformers<=4.51.3, >=4.48.0, !=4.51.2", True, ("transformers", "4.51.1"), None),
|
||||
("transformers<=4.47.1", True, ("transformers", "4.47.1"), None),
|
||||
("transformers<=4.47.1", False, ("transformers", "4.48.0"), None),
|
||||
("transformers<0.5.0", True, ("transformers", "0.4.0"), None),
|
||||
("transformers<0.5.0", False, ("transformers", "0.5.0"), None),
|
||||
("transformers!=0.5.0", True, ("transformers", "0.6.0"), None),
|
||||
("transformers!=0.5.0", False, ("transformers", "0.5.0"), None),
|
||||
|
||||
# Range comparisons (multiple conditions)
|
||||
("requests>=2.0.0,<3.0.0", True, ("requests", "2.5.1"), None),
|
||||
("requests>=2.0.0,<3.0.0", False, ("requests", "3.0.0"), None),
|
||||
("requests>=2.0.0,<3.0.0", False, ("requests", "1.9.0"), None),
|
||||
("my-pkg>1.0,<=2.0.0rc1", True, ("my-pkg", "2.0.0rc1"), None),
|
||||
("my-pkg>1.0,<=2.0.0rc1", True, ("my-pkg", "1.5"), None),
|
||||
("my-pkg>1.0,<=2.0.0rc1", False, ("my-pkg", "1.0"), None), # >1.0 fails for 1.0
|
||||
("my-pkg>1.0,<=2.0.0rc1", False, ("my-pkg", "2.0.0"), None), # >2.0.0rc1
|
||||
|
||||
# Existence check
|
||||
("numpy", True, ("numpy", "1.20.0"), None),
|
||||
("nonexistent_pkg", False, None, None), # Package not found
|
||||
|
||||
# Python version check
|
||||
("python>=3.8.0", True, None, (3, 8, 5)),
|
||||
("python>=3.8.0", True, None, (3, 9, 1)),
|
||||
("python>=3.8.0", False, None, (3, 7, 0)),
|
||||
("python==3.9.1", True, None, (3, 9, 1)),
|
||||
("python==3.9.1", False, None, (3, 9, 0)),
|
||||
("python<3.10,>=3.8", True, None, (3, 9, 5)),
|
||||
("python<3.10,>=3.8", False, None, (3, 7, 5)),
|
||||
("python<3.10,>=3.8", False, None, (3, 10, 0)),
|
||||
|
||||
# Semantic versioning nuances (e.g., 4.5 vs 4.45.0)
|
||||
("torch>=4.45.0", False, ("torch", "4.5.0"), None), # 4.5.0 is less than 4.45.0
|
||||
("torch<4.45.0", True, ("torch", "4.5.0"), None),
|
||||
("torch>=4.5.0", True, ("torch", "4.45.0"), None), # 4.45.0 is greater than 4.5.0
|
||||
("torch==4.5.0", True, ("torch", "4.5"), None), # 4.5.0 and 4.5 are equivalent by packaging.version
|
||||
("torch==4.5", True, ("torch", "4.5.0"), None),
|
||||
|
||||
# Malformed requirement strings
|
||||
("some_pkg>>1.0", False, None, None), # Invalid operator
|
||||
("another_pkg>=1.0,<", False, None, None), # Incomplete condition
|
||||
("yetanother_pkg>=1.0,invalid<2.0", False, None, None), # Malformed second part
|
||||
("", False, None, None), # Empty string
|
||||
("pkg,=1.0", False, None, None), # Malformed
|
||||
("pkg>=1..0", False, None, None), # Invalid version in requirement
|
||||
|
||||
# Package with invalid installed version
|
||||
("badverpkg>=1.0", False, ("badverpkg", "1.0-invalid-version"), None),
|
||||
|
||||
# Spaces and complex names
|
||||
(" my.pkg-name_with_stuff >= 1.0.0a1 , <1.1.0.post2 ", True, ("my.pkg-name_with_stuff", "1.0.1"), None),
|
||||
(" my.pkg-name_with_stuff >= 1.0.0a1 , <1.1.0.post2 ", False, ("my.pkg-name_with_stuff", "1.2.0"), None),
|
||||
]
|
||||
|
||||
passed_all = True
|
||||
for req_str, expected, pkg_setup, py_setup in test_cases:
|
||||
# Reset mocks for each test
|
||||
self.mock_meta.installed_packages.clear()
|
||||
if pkg_setup:
|
||||
self.mock_meta.add_package(pkg_setup[0], pkg_setup[1])
|
||||
if py_setup:
|
||||
sys.version_info = MockSys(py_setup).version_info
|
||||
else: # Default python version if not specified for test
|
||||
sys.version_info = MockSys((3, 9, 0)).version_info # A default for non-python tests
|
||||
|
||||
# result = check_package_version(req_str)
|
||||
result = check_package_version(req_str)
|
||||
|
||||
if result == expected:
|
||||
print(f"PASS: '{req_str}' -> {result} (Expected: {expected})")
|
||||
else:
|
||||
print(f"FAIL: '{req_str}' -> {result} (Expected: {expected})")
|
||||
print(f" Setup: pkg={pkg_setup}, py={py_setup}")
|
||||
passed_all = False
|
||||
|
||||
# Restore original functions
|
||||
importlib.metadata.version = self.original_importlib_metadata_version
|
||||
sys.version_info = self.original_sys_version_info
|
||||
|
||||
self.assertTrue(passed_all)
|
||||
|
Reference in New Issue
Block a user