Compare commits

...

2 Commits

Author SHA1 Message Date
207376de62 Release v0.11.1 2024-05-17 12:48:49 +02:00
2276c6e190 FIX BOFT setting env vars breaks C++ compilation (#1739)
Resolves #1738
2024-05-17 12:45:46 +02:00
3 changed files with 49 additions and 14 deletions

View File

@ -15,7 +15,7 @@
from setuptools import find_packages, setup
VERSION = "0.11.0"
VERSION = "0.11.1"
extras = {}
extras["quality"] = [

View File

@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.11.0"
__version__ = "0.11.1"
from .auto import (
AutoPeftModel,

View File

@ -20,6 +20,7 @@ from __future__ import annotations
import math
import os
import warnings
from contextlib import contextmanager
from typing import Any, Optional, Union
import torch
@ -31,13 +32,46 @@ from torch.utils.cpp_extension import load
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
os.environ["CC"] = "gcc"
os.environ["CXX"] = "gcc"
curr_dir = os.path.dirname(__file__)
_FBD_CUDA = None
# this function is a 1:1 copy from accelerate
@contextmanager
def patch_environment(**kwargs):
"""
A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
Will convert the values in `kwargs` to strings and upper-case all the keys.
Example:
```python
>>> import os
>>> from accelerate.utils import patch_environment
>>> with patch_environment(FOO="bar"):
... print(os.environ["FOO"]) # prints "bar"
>>> print(os.environ["FOO"]) # raises KeyError
```
"""
existing_vars = {}
for key, value in kwargs.items():
key = key.upper()
if key in os.environ:
existing_vars[key] = os.environ[key]
os.environ[key] = str(value)
yield
for key in kwargs:
key = key.upper()
if key in existing_vars:
# restore previous value
os.environ[key] = existing_vars[key]
else:
os.environ.pop(key, None)
def get_fbd_cuda():
global _FBD_CUDA
@ -47,14 +81,15 @@ def get_fbd_cuda():
curr_dir = os.path.dirname(__file__)
# need ninja to build the extension
try:
fbd_cuda = load(
name="fbd_cuda",
sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"],
verbose=True,
# build_directory='/tmp/' # for debugging
)
# extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
import fbd_cuda
with patch_environment(CC="gcc", CXX="gcc"):
fbd_cuda = load(
name="fbd_cuda",
sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"],
verbose=True,
# build_directory='/tmp/' # for debugging
)
# extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
import fbd_cuda
except Exception as e:
warnings.warn(f"Failed to load the CUDA extension: {e}, check if ninja is available.")
warnings.warn("Setting boft_n_butterfly_factor to 1 to speed up the finetuning process.")