FEAT Add LoRA INC support (#2499)

Adds LoRA support for models quantized with Intel Neural Compressor (INC).

---------

Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Daniel Socek
2025-04-28 12:39:37 -04:00
committed by GitHub
parent 453a6ff336
commit 003cf20bcd
6 changed files with 195 additions and 1 deletion

View File: docs/source/developer_guides/quantization.md

@@ -192,7 +192,7 @@ model = get_peft_model(model, config)
## HQQ quantization
The models that is quantized using Half-Quadratic Quantization of Large Machine Learning Models ([HQQ](https://mobiusml.github.io/hqq_blog/)) support LoRA adapter tuning. To tune the quantized model, you'll need to install the `hqq` library with: `pip install hqq`.
The models that are quantized using Half-Quadratic Quantization of Large Machine Learning Models ([HQQ](https://mobiusml.github.io/hqq_blog/)) support LoRA adapter tuning. To tune the quantized model, you'll need to install the `hqq` library with: `pip install hqq`.
```python
from hqq.engine.hf import HQQModelForCausalLM
@@ -237,6 +237,45 @@ model = get_peft_model(base_model, peft_config)
- DoRA only works with `quant_type = "int8_weight_only"` at the moment.
- There is explicit support for torchao when used with LoRA. However, when torchao quantizes a layer, its class does not change, only the type of the underlying tensor. For this reason, PEFT methods other than LoRA will generally also work with torchao, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA and with `quant_type = "int8_weight_only"`**. If you use a different PEFT method or dtype, merging will likely result in an error, and even if it doesn't, the results will still be incorrect.
## INC quantization
Intel Neural Compressor ([INC](https://github.com/intel/neural-compressor)) enables model quantization for various devices,
including Intel Gaudi accelerators (also known as HPU devices). You can perform LoRA fine-tuning on models that have been
quantized using INC. To use INC with PyTorch models, install the library with: `pip install neural-compressor[pt]`.
Quantizing a model to FP8 precision for HPU devices can be done with the following single-step quantization workflow:
```python
import torch
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare
quant_configs = {
...
}
config = FP8Config(**quant_configs)
```
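The `...` placeholder stands for the FP8 quantization options. For reference, the FLUX example linked further below uses a configuration along these lines; the exact values are illustrative and depend on your device and workload:
```python
# Continuing from the imports above; settings mirror the HPU FLUX example below
quant_configs = {
    "mode": "AUTO",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "allowlist": {"types": [], "names": []},
    "blocklist": {"types": [], "names": []},
    "dump_stats_path": "/tmp/hqt_output/measure",
}
config = FP8Config(**quant_configs)
```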
Pass the config to the `prepare` method, run inference to gather calibration statistics, and then call the `finalize_calibration`
and `convert` methods to quantize the model to FP8 precision:
```python
model = prepare(model, config)
# Run inference to collect calibration statistics
...
# Finalize calibration and convert the model to FP8 precision
finalize_calibration(model)
model = convert(model)
# Load PEFT LoRA adapter as usual
...
```
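The last step above ("Load PEFT LoRA adapter as usual") can, for instance, attach a fresh LoRA adapter with `get_peft_model`. This is only a minimal sketch: the rank, alpha, and `target_modules` below are illustrative assumptions and should be chosen to match your model.
```python
from peft import LoraConfig, get_peft_model

# Hypothetical LoRA settings; pick target_modules that exist in your architecture
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)
```
An existing adapter can instead be loaded with `PeftModel.from_pretrained(model, <adapter_path>)` or, for diffusers pipelines, with `pipe.load_lora_weights(...)` as shown in the FLUX example below.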
An example demonstrating how to load a PEFT LoRA adapter into an INC-quantized FLUX text-to-image model for HPU
devices is provided [here](https://github.com/huggingface/peft/blob/main/examples/stable_diffusion/inc_flux_lora_hpu.py).
### Caveats:
- `merge()` and `unmerge()` methods are currently not supported for INC-quantized models (see the sketch below).
- Currently, only **Linear** INC-quantized layers are supported when loading PEFT adapters.
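For example, calling `merge_adapter()` on a LoRA-wrapped, INC-quantized model (such as the one sketched above) currently raises `NotImplementedError`:
```python
try:
    model.merge_adapter()
except NotImplementedError as exc:
    # INC LoRA layers do not implement merge()/unmerge() yet
    print(f"Merging is not supported for INC-quantized layers: {exc}")
```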
## Other Supported PEFT Methods
Besides LoRA, the following PEFT methods also support quantization:

View File: examples/stable_diffusion/inc_flux_lora_hpu.py

@@ -0,0 +1,67 @@
"""
This example demonstrates loading a LoRA adapter (via PEFT) into an FP8 INC-quantized FLUX model.
More info on Intel Neural Compressor (INC) FP8 quantization is available at:
https://github.com/intel/neural-compressor/tree/master/examples/helloworld/fp8_example
Requirements:
pip install optimum-habana sentencepiece neural-compressor[pt] peft
"""
import importlib
import torch
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare
# Checks if HPU device is available
# Adapted from https://github.com/huggingface/accelerate/blob/b451956fd69a135efc283aadaa478f0d33fcbe6a/src/accelerate/utils/imports.py#L435
def is_hpu_available():
if (
importlib.util.find_spec("habana_frameworks") is None
or importlib.util.find_spec("habana_frameworks.torch") is None
):
return False
import habana_frameworks.torch # noqa: F401
return hasattr(torch, "hpu") and torch.hpu.is_available()
# Ensure HPU device is available before proceeding
if is_hpu_available():
from optimum.habana.diffusers import GaudiFluxPipeline
else:
raise RuntimeError("HPU device not found. This code requires Intel Gaudi device to run.")
# Example: FLUX model inference on HPU via optimum-habana pipeline
hpu_configs = {
"use_habana": True,
"use_hpu_graphs": True,
"sdp_on_bf16": True,
"gaudi_config": "Habana/stable-diffusion",
}
pipe = GaudiFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **hpu_configs)
prompt = "A picture of sks dog in a bucket"
# Quantize FLUX transformer to FP8 using INC (Intel Neural Compressor)
quant_configs = {
"mode": "AUTO",
"observer": "maxabs",
"scale_method": "maxabs_hw",
"allowlist": {"types": [], "names": []},
"blocklist": {"types": [], "names": []},
"dump_stats_path": "/tmp/hqt_output/measure",
}
config = FP8Config(**quant_configs)
pipe.transformer = prepare(pipe.transformer, config)
pipe(prompt)
finalize_calibration(pipe.transformer)
pipe.transformer = convert(pipe.transformer)
# Load LoRA weights with PEFT
pipe.load_lora_weights("dsocek/lora-flux-dog", adapter_name="user_lora")
# Run inference
image = pipe(prompt).images[0]
image.save("dog.png")

View File: src/peft/import_utils.py

@@ -118,6 +118,11 @@ def is_hqq_available():
return importlib.util.find_spec("hqq") is not None
@lru_cache
def is_inc_available():
return importlib.util.find_spec("neural_compressor") is not None
@lru_cache
def is_torchao_available():
if importlib.util.find_spec("torchao") is None:

View File: src/peft/tuners/lora/inc.py

@@ -0,0 +1,78 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: PEFT tests related to INC are handled under Optimum-Habana repository:
# - LLMs: https://github.com/huggingface/optimum-habana/blob/main/tests/test_peft_inference.py
# - Diffusers: https://github.com/huggingface/optimum-habana/blob/main/tests/test_diffusers.py
from typing import Optional
import torch
from peft.import_utils import is_inc_available
from peft.tuners.tuners_utils import BaseTunerLayer
from .layer import Linear
if is_inc_available():
class IncLoraLinear(Linear):
def __init__(
self,
base_layer: torch.nn.Module,
adapter_name: str,
**kwargs,
):
super().__init__(base_layer, adapter_name, **kwargs)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged.
Defaults to `None`.
"""
raise NotImplementedError("Merging LoRA with INC layers is not yet implemented")
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
raise NotImplementedError("Unmerging LoRA from INC layers is not yet implemented")
def dispatch_inc(target: torch.nn.Module, adapter_name: str, **kwargs):
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if is_inc_available():
from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import (
PatchedLinear,
)
if isinstance(target_base_layer, PatchedLinear):
new_module = IncLoraLinear(target, adapter_name, **kwargs)
return new_module

View File: src/peft/tuners/lora/layer.py

@@ -145,6 +145,9 @@ class LoraLayer(BaseTunerLayer):
elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
# HQQ layers
in_features, out_features = base_layer.in_features, base_layer.out_features
elif base_layer.__class__.__name__ == "PatchedLinear":
# INC layers
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
# possibly support user provided custom layer types using dynamic dispatch
if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):

View File: src/peft/tuners/lora/model.py

@@ -52,6 +52,7 @@ from .config import LoraConfig
from .eetq import dispatch_eetq
from .gptq import dispatch_gptq
from .hqq import dispatch_hqq
from .inc import dispatch_inc
from .layer import Conv2d, LoraLayer, dispatch_default
from .torchao import dispatch_torchao
from .tp_layer import dispatch_megatron
@@ -331,6 +332,7 @@ class LoraModel(BaseTuner):
dispatch_awq,
dispatch_gptq,
dispatch_hqq,
dispatch_inc,
dispatch_torchao,
dispatch_megatron,
dispatch_default,