mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-18 00:14:36 +08:00
Compare commits
1 Commits
v1.11.0
...
make-versi
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a04b8b58e |
@ -175,7 +175,7 @@ from .operations import (
|
||||
send_to_device,
|
||||
slice_tensors,
|
||||
)
|
||||
from .versions import compare_versions, is_torch_version
|
||||
from .versions import compare_versions, is_torch_version, parse
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
|
||||
@ -19,11 +19,9 @@ import warnings
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from packaging.version import parse
|
||||
|
||||
from .environment import parse_flag_from_env, str_to_bool
|
||||
from .versions import compare_versions, is_torch_version
|
||||
from .versions import compare_versions, is_torch_version, parse
|
||||
|
||||
|
||||
# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
|
||||
@ -180,7 +178,7 @@ def is_deepspeed_available():
|
||||
def is_pippy_available():
|
||||
package_exists = _is_package_available("pippy", "torchpippy")
|
||||
if package_exists:
|
||||
pippy_version = version.parse(importlib.metadata.version("torchpippy"))
|
||||
pippy_version = parse(importlib.metadata.version("torchpippy"))
|
||||
return compare_versions(pippy_version, ">", "0.1.1")
|
||||
return False
|
||||
|
||||
@ -199,7 +197,7 @@ def is_bf16_available(ignore_tpu=False):
|
||||
def is_4bit_bnb_available():
|
||||
package_exists = _is_package_available("bitsandbytes")
|
||||
if package_exists:
|
||||
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
|
||||
bnb_version = parse(importlib.metadata.version("bitsandbytes"))
|
||||
return compare_versions(bnb_version, ">=", "0.39.0")
|
||||
return False
|
||||
|
||||
@ -207,7 +205,7 @@ def is_4bit_bnb_available():
|
||||
def is_8bit_bnb_available():
|
||||
package_exists = _is_package_available("bitsandbytes")
|
||||
if package_exists:
|
||||
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
|
||||
bnb_version = parse(importlib.metadata.version("bitsandbytes"))
|
||||
return compare_versions(bnb_version, ">=", "0.37.2")
|
||||
return False
|
||||
|
||||
@ -255,7 +253,7 @@ def is_triton_available():
|
||||
def is_aim_available():
|
||||
package_exists = _is_package_available("aim")
|
||||
if package_exists:
|
||||
aim_version = version.parse(importlib.metadata.version("aim"))
|
||||
aim_version = parse(importlib.metadata.version("aim"))
|
||||
return compare_versions(aim_version, "<", "4.0.0")
|
||||
return False
|
||||
|
||||
@ -324,7 +322,7 @@ def is_mps_available(min_version="1.12"):
|
||||
|
||||
def is_ipex_available():
|
||||
def get_major_and_minor_from_version(full_version):
|
||||
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
||||
return str(parse(full_version).major) + "." + str(parse(full_version).minor)
|
||||
|
||||
_torch_version = importlib.metadata.version("torch")
|
||||
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
|
||||
|
||||
@ -15,11 +15,20 @@
|
||||
import importlib.metadata
|
||||
from typing import Union
|
||||
|
||||
from packaging.version import Version, parse
|
||||
from packaging.version import Version
|
||||
from packaging.version import parse as _parse
|
||||
|
||||
from .constants import STR_OPERATION_TO_FUNC
|
||||
|
||||
|
||||
def parse(version: str):
|
||||
"""
|
||||
Same as `packaging.version.parse`, but grabs strictly the base version.
|
||||
"""
|
||||
version = _parse(version)
|
||||
return _parse(version.base_version)
|
||||
|
||||
|
||||
torch_version = parse(importlib.metadata.version("torch"))
|
||||
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ from accelerate.utils import (
|
||||
listify,
|
||||
pad_across_processes,
|
||||
pad_input_tensors,
|
||||
parse,
|
||||
patch_environment,
|
||||
recursively_apply,
|
||||
save,
|
||||
@ -411,3 +412,8 @@ class UtilsTester(unittest.TestCase):
|
||||
tqdm(True, range(3), disable=True)
|
||||
assert "Passing `True` as the first argument to" in cm.pop().message.args[0]
|
||||
tqdm(range(3), main_process_only=True, disable=True)
|
||||
|
||||
def test_dev0_parsing(self):
|
||||
v1 = parse("0.34.0.dev0")
|
||||
v2 = parse("0.34.0")
|
||||
assert v1 == v2
|
||||
|
||||
Reference in New Issue
Block a user