Compare commits

...

1 Commits

4 changed files with 23 additions and 10 deletions

View File

@ -175,7 +175,7 @@ from .operations import (
send_to_device,
slice_tensors,
)
from .versions import compare_versions, is_torch_version
from .versions import compare_versions, is_torch_version, parse
if is_deepspeed_available():

View File

@ -19,11 +19,9 @@ import warnings
from functools import lru_cache
import torch
from packaging import version
from packaging.version import parse
from .environment import parse_flag_from_env, str_to_bool
from .versions import compare_versions, is_torch_version
from .versions import compare_versions, is_torch_version, parse
# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
@ -180,7 +178,7 @@ def is_deepspeed_available():
def is_pippy_available():
package_exists = _is_package_available("pippy", "torchpippy")
if package_exists:
pippy_version = version.parse(importlib.metadata.version("torchpippy"))
pippy_version = parse(importlib.metadata.version("torchpippy"))
return compare_versions(pippy_version, ">", "0.1.1")
return False
@ -199,7 +197,7 @@ def is_bf16_available(ignore_tpu=False):
def is_4bit_bnb_available():
package_exists = _is_package_available("bitsandbytes")
if package_exists:
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
bnb_version = parse(importlib.metadata.version("bitsandbytes"))
return compare_versions(bnb_version, ">=", "0.39.0")
return False
@ -207,7 +205,7 @@ def is_4bit_bnb_available():
def is_8bit_bnb_available():
package_exists = _is_package_available("bitsandbytes")
if package_exists:
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
bnb_version = parse(importlib.metadata.version("bitsandbytes"))
return compare_versions(bnb_version, ">=", "0.37.2")
return False
@ -255,7 +253,7 @@ def is_triton_available():
def is_aim_available():
package_exists = _is_package_available("aim")
if package_exists:
aim_version = version.parse(importlib.metadata.version("aim"))
aim_version = parse(importlib.metadata.version("aim"))
return compare_versions(aim_version, "<", "4.0.0")
return False
@ -324,7 +322,7 @@ def is_mps_available(min_version="1.12"):
def is_ipex_available():
def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
return str(parse(full_version).major) + "." + str(parse(full_version).minor)
_torch_version = importlib.metadata.version("torch")
if importlib.util.find_spec("intel_extension_for_pytorch") is None:

View File

@ -15,11 +15,20 @@
import importlib.metadata
from typing import Union
from packaging.version import Version, parse
from packaging.version import Version
from packaging.version import parse as _parse
from .constants import STR_OPERATION_TO_FUNC
def parse(version: str):
"""
Same as `packaging.version.parse`, but grabs strictly the base version.
"""
version = _parse(version)
return _parse(version.base_version)
torch_version = parse(importlib.metadata.version("torch"))

View File

@ -49,6 +49,7 @@ from accelerate.utils import (
listify,
pad_across_processes,
pad_input_tensors,
parse,
patch_environment,
recursively_apply,
save,
@ -411,3 +412,8 @@ class UtilsTester(unittest.TestCase):
tqdm(True, range(3), disable=True)
assert "Passing `True` as the first argument to" in cm.pop().message.args[0]
tqdm(range(3), main_process_only=True, disable=True)
def test_dev0_parsing(self):
v1 = parse("0.34.0.dev0")
v2 = parse("0.34.0")
assert v1 == v2