mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
FEAT Add LoRA INC support (#2499)
Adds LoRA fine-tuning support for models quantized with Intel Neural Compressor (INC).

Signed-off-by: Daniel Socek <daniel.socek@intel.com>
@@ -192,7 +192,7 @@ model = get_peft_model(model, config)

## HQQ quantization

-The models that is quantized using Half-Quadratic Quantization of Large Machine Learning Models ([HQQ](https://mobiusml.github.io/hqq_blog/)) support LoRA adapter tuning. To tune the quantized model, you'll need to install the `hqq` library with: `pip install hqq`.
+The models that are quantized using Half-Quadratic Quantization of Large Machine Learning Models ([HQQ](https://mobiusml.github.io/hqq_blog/)) support LoRA adapter tuning. To tune the quantized model, you'll need to install the `hqq` library with: `pip install hqq`.

```python
from hqq.engine.hf import HQQModelForCausalLM
@@ -237,6 +237,45 @@ model = get_peft_model(base_model, peft_config)

- DoRA only works with `quant_type = "int8_weight_only"` at the moment.
- There is explicit support for torchao when used with LoRA. However, when torchao quantizes a layer, its class does not change, only the type of the underlying tensor. For this reason, PEFT methods other than LoRA will generally also work with torchao, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA and with `quant_type = "int8_weight_only"`**. If you use a different PEFT method or dtype, merging will likely result in an error, and even if it doesn't, the results will still be incorrect.

## INC quantization

Intel Neural Compressor ([INC](https://github.com/intel/neural-compressor)) enables model quantization for various devices, including Intel Gaudi accelerators (also known as HPU devices). You can perform LoRA fine-tuning on models that have been quantized using INC. To use INC with PyTorch models, install the library with: `pip install neural-compressor[pt]`.

Quantizing a model to FP8 precision for HPU devices can be done with the following single-step quantization workflow:

```python
import torch
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare

quant_configs = {
    ...
}
config = FP8Config(**quant_configs)
```
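
The exact contents of `quant_configs` depend on your calibration and quantization setup. As one illustration, the settings used in the FLUX example referenced below look like this (treat them as a starting point rather than the definitive configuration):

```python
from neural_compressor.torch.quantization import FP8Config

quant_configs = {
    "mode": "AUTO",                                # single-step mode: measure and quantize in one pass
    "observer": "maxabs",                          # observer used to collect calibration statistics
    "scale_method": "maxabs_hw",                   # how FP8 scales are derived from the observed stats
    "allowlist": {"types": [], "names": []},       # modules to include (empty lists = defaults)
    "blocklist": {"types": [], "names": []},       # modules to exclude from quantization
    "dump_stats_path": "/tmp/hqt_output/measure",  # where calibration statistics are written
}
config = FP8Config(**quant_configs)
```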

Pass the config to the `prepare` method, run inference to gather calibration stats, and call the `finalize_calibration` and `convert` methods to quantize the model to FP8 precision:

```python
model = prepare(model, config)

# Run inference to collect calibration statistics
...

# Finalize calibration and convert the model to FP8 precision
finalize_calibration(model)
model = convert(model)

# Load PEFT LoRA adapter as usual
...
```
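
As a rough end-to-end sketch for a causal language model (the checkpoint path, calibration prompt, and adapter ID below are placeholders, the `FP8Config` values follow the example above, and device placement on the Gaudi/HPU device is omitted for brevity):

```python
import torch
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_id = "path/to/base-model"      # placeholder base checkpoint
lora_adapter_id = "path/to/lora-adapter"  # placeholder LoRA adapter

tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)

config = FP8Config(mode="AUTO", observer="maxabs", scale_method="maxabs_hw")
model = prepare(model, config)

# Run inference on representative inputs to collect calibration statistics
inputs = tokenizer("A representative calibration prompt", return_tensors="pt")
with torch.no_grad():
    model(**inputs)

# Finalize calibration and convert the model to FP8 precision
finalize_calibration(model)
model = convert(model)

# Load the PEFT LoRA adapter on top of the INC-quantized model
model = PeftModel.from_pretrained(model, lora_adapter_id)
```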

An example demonstrating how to load a PEFT LoRA adapter into an INC-quantized FLUX text-to-image model for HPU devices is provided [here](https://github.com/huggingface/peft/blob/main/examples/stable_diffusion/inc_flux_lora_hpu.py).

### Caveats:

- The `merge()` and `unmerge()` methods are currently not supported for INC-quantized models (a short sketch of the resulting error follows this list).
- Currently, only **Linear** INC-quantized layers are supported when loading PEFT adapters.
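
As a sketch of the first caveat, calling a merge on a PEFT-wrapped model whose base `Linear` layers are INC-quantized is expected to raise a `NotImplementedError` (see `src/peft/tuners/lora/inc.py` below):

```python
try:
    model.merge_adapter()  # dispatches to IncLoraLinear.merge() for INC-quantized layers
except NotImplementedError as err:
    print(f"Merging LoRA into INC layers is not supported yet: {err}")
```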

## Other Supported PEFT Methods

Besides LoRA, the following PEFT methods also support quantization:

examples/stable_diffusion/inc_flux_lora_hpu.py (new file)
@@ -0,0 +1,67 @@
"""
|
||||
This exampe demonstrates loading of LoRA adapter (via PEFT) into an FP8 INC-quantized FLUX model.
|
||||
|
||||
More info on Intel Neural Compressor (INC) FP8 quantization is available at:
|
||||
https://github.com/intel/neural-compressor/tree/master/examples/helloworld/fp8_example
|
||||
|
||||
Requirements:
|
||||
pip install optimum-habana sentencepiece neural-compressor[pt] peft
|
||||
"""
|
||||
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare
|
||||
|
||||
|
||||
# Checks if HPU device is available
|
||||
# Adapted from https://github.com/huggingface/accelerate/blob/b451956fd69a135efc283aadaa478f0d33fcbe6a/src/accelerate/utils/imports.py#L435
|
||||
def is_hpu_available():
|
||||
if (
|
||||
importlib.util.find_spec("habana_frameworks") is None
|
||||
or importlib.util.find_spec("habana_frameworks.torch") is None
|
||||
):
|
||||
return False
|
||||
|
||||
import habana_frameworks.torch # noqa: F401
|
||||
|
||||
return hasattr(torch, "hpu") and torch.hpu.is_available()
|
||||
|
||||
|
||||
# Ensure HPU device is available before proceeding
|
||||
if is_hpu_available():
|
||||
from optimum.habana.diffusers import GaudiFluxPipeline
|
||||
else:
|
||||
raise RuntimeError("HPU device not found. This code requires Intel Gaudi device to run.")
|
||||
|
||||
# Example: FLUX model inference on HPU via optimum-habana pipeline
|
||||
hpu_configs = {
|
||||
"use_habana": True,
|
||||
"use_hpu_graphs": True,
|
||||
"sdp_on_bf16": True,
|
||||
"gaudi_config": "Habana/stable-diffusion",
|
||||
}
|
||||
pipe = GaudiFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **hpu_configs)
|
||||
prompt = "A picture of sks dog in a bucket"
|
||||
|
||||
# Quantize FLUX transformer to FP8 using INC (Intel Neural Compressor)
|
||||
quant_configs = {
|
||||
"mode": "AUTO",
|
||||
"observer": "maxabs",
|
||||
"scale_method": "maxabs_hw",
|
||||
"allowlist": {"types": [], "names": []},
|
||||
"blocklist": {"types": [], "names": []},
|
||||
"dump_stats_path": "/tmp/hqt_output/measure",
|
||||
}
|
||||
config = FP8Config(**quant_configs)
|
||||
pipe.transformer = prepare(pipe.transformer, config)
|
||||
pipe(prompt)
|
||||
finalize_calibration(pipe.transformer)
|
||||
pipe.transformer = convert(pipe.transformer)
|
||||
|
||||
# Load LoRA weights with PEFT
|
||||
pipe.load_lora_weights("dsocek/lora-flux-dog", adapter_name="user_lora")
|
||||
|
||||
# Run inference
|
||||
image = pipe(prompt).images[0]
|
||||
image.save("dog.png")
|

src/peft/import_utils.py
@@ -118,6 +118,11 @@ def is_hqq_available():
    return importlib.util.find_spec("hqq") is not None


@lru_cache
def is_inc_available():
    return importlib.util.find_spec("neural_compressor") is not None


@lru_cache
def is_torchao_available():
    if importlib.util.find_spec("torchao") is None:

src/peft/tuners/lora/inc.py (new file)
@@ -0,0 +1,78 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: PEFT tests related to INC are handled under the Optimum-Habana repository:
# - LLMs: https://github.com/huggingface/optimum-habana/blob/main/tests/test_peft_inference.py
# - Diffusers: https://github.com/huggingface/optimum-habana/blob/main/tests/test_diffusers.py

from typing import Optional

import torch

from peft.import_utils import is_inc_available
from peft.tuners.tuners_utils import BaseTunerLayer

from .layer import Linear


if is_inc_available():

    class IncLoraLinear(Linear):
        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            **kwargs,
        ):
            super().__init__(base_layer, adapter_name, **kwargs)

        def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
            """
            Merge the active adapter weights into the base weights.

            Args:
                safe_merge (`bool`, *optional*):
                    If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                    before merging the weights. This is useful if you want to check if the merge operation will produce
                    NaNs. Defaults to `False`.
                adapter_names (`list[str]`, *optional*):
                    The list of adapter names that should be merged. If None, all active adapters will be merged.
                    Defaults to `None`.
            """
            raise NotImplementedError("Merging LoRA with INC layers is not yet implemented")

        def unmerge(self) -> None:
            """
            This method unmerges all merged adapter layers from the base weights.
            """
            raise NotImplementedError("Unmerging LoRA from INC layers is not yet implemented")


def dispatch_inc(target: torch.nn.Module, adapter_name: str, **kwargs):
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_inc_available():
        from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import (
            PatchedLinear,
        )

        if isinstance(target_base_layer, PatchedLinear):
            new_module = IncLoraLinear(target, adapter_name, **kwargs)

    return new_module

src/peft/tuners/lora/layer.py
@@ -145,6 +145,9 @@ class LoraLayer(BaseTunerLayer):
        elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
            # HQQ layers
            in_features, out_features = base_layer.in_features, base_layer.out_features
        elif base_layer.__class__.__name__ == "PatchedLinear":
            # INC layers
            in_features, out_features = base_layer.in_features, base_layer.out_features
        else:
            # possibly support user provided custom layer types using dynamic dispatch
            if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):

src/peft/tuners/lora/model.py
@@ -52,6 +52,7 @@ from .config import LoraConfig
from .eetq import dispatch_eetq
from .gptq import dispatch_gptq
from .hqq import dispatch_hqq
from .inc import dispatch_inc
from .layer import Conv2d, LoraLayer, dispatch_default
from .torchao import dispatch_torchao
from .tp_layer import dispatch_megatron

@@ -331,6 +332,7 @@ class LoraModel(BaseTuner):
            dispatch_awq,
            dispatch_gptq,
            dispatch_hqq,
            dispatch_inc,
            dispatch_torchao,
            dispatch_megatron,
            dispatch_default,