mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
Compare commits
2 Commits
87b90f045e
...
v0.11.1
Author | SHA1 | Date | |
---|---|---|---|
207376de62 | |||
2276c6e190 |
2
setup.py
2
setup.py
@ -15,7 +15,7 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
VERSION = "0.11.0"
|
||||
VERSION = "0.11.1"
|
||||
|
||||
extras = {}
|
||||
extras["quality"] = [
|
||||
|
@ -17,7 +17,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "0.11.0"
|
||||
__version__ = "0.11.1"
|
||||
|
||||
from .auto import (
|
||||
AutoPeftModel,
|
||||
|
@ -20,6 +20,7 @@ from __future__ import annotations
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -31,13 +32,46 @@ from torch.utils.cpp_extension import load
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
|
||||
|
||||
os.environ["CC"] = "gcc"
|
||||
os.environ["CXX"] = "gcc"
|
||||
curr_dir = os.path.dirname(__file__)
|
||||
|
||||
_FBD_CUDA = None
|
||||
|
||||
|
||||
# this function is a 1:1 copy from accelerate
|
||||
@contextmanager
|
||||
def patch_environment(**kwargs):
|
||||
"""
|
||||
A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
|
||||
|
||||
Will convert the values in `kwargs` to strings and upper-case all the keys.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import os
|
||||
>>> from accelerate.utils import patch_environment
|
||||
|
||||
>>> with patch_environment(FOO="bar"):
|
||||
... print(os.environ["FOO"]) # prints "bar"
|
||||
>>> print(os.environ["FOO"]) # raises KeyError
|
||||
```
|
||||
"""
|
||||
existing_vars = {}
|
||||
for key, value in kwargs.items():
|
||||
key = key.upper()
|
||||
if key in os.environ:
|
||||
existing_vars[key] = os.environ[key]
|
||||
os.environ[key] = str(value)
|
||||
|
||||
yield
|
||||
|
||||
for key in kwargs:
|
||||
key = key.upper()
|
||||
if key in existing_vars:
|
||||
# restore previous value
|
||||
os.environ[key] = existing_vars[key]
|
||||
else:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
|
||||
def get_fbd_cuda():
|
||||
global _FBD_CUDA
|
||||
|
||||
@ -47,14 +81,15 @@ def get_fbd_cuda():
|
||||
curr_dir = os.path.dirname(__file__)
|
||||
# need ninja to build the extension
|
||||
try:
|
||||
fbd_cuda = load(
|
||||
name="fbd_cuda",
|
||||
sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"],
|
||||
verbose=True,
|
||||
# build_directory='/tmp/' # for debugging
|
||||
)
|
||||
# extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
|
||||
import fbd_cuda
|
||||
with patch_environment(CC="gcc", CXX="gcc"):
|
||||
fbd_cuda = load(
|
||||
name="fbd_cuda",
|
||||
sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"],
|
||||
verbose=True,
|
||||
# build_directory='/tmp/' # for debugging
|
||||
)
|
||||
# extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
|
||||
import fbd_cuda
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to load the CUDA extension: {e}, check if ninja is available.")
|
||||
warnings.warn("Setting boft_n_butterfly_factor to 1 to speed up the finetuning process.")
|
||||
|
Reference in New Issue
Block a user