TST: torch compile tests (#1725)

Right now, we don't have specific tests for torch.compile. Instead, we
have a "hack" that allows to run _all_ tests with torch.compile if we
set the environment variable PEFT_DEBUG_WITH_TORCH_COMPILE=1.

This is not very practical because it takes a lot of time to run all
these tests with compilation enabled. Also, currently hundreds of tests
are failing, which makes it impossible to understand more closely what
does or does not work.

This PR removes the aforementioned "hack" and instead replaces it with a
list of explicit torch.compile tests. Currently, these tests cover
training/inference with a bunch of different tuner types, as well as
more advanced features with LoRA (e.g. quantization, multiple adapters,
etc.).

Some of these tests pass and some of them fail. This is documented now,
so that users can quickly look up if their use case would be compatible
with torch.compile. This is very useful to have, because sometimes
torch.compile may appear to work but actually returns the wrong result.
For users, it's not immediately obvious when this happens.

The test suite is not exhaustive, there are many combinations of
features that could be added. However, it should be a good starting
point and can be extended later.

The test suite does _not_ cover whether torch.compile actually
accelerates the code. This may not be the case even if it works
correctly (e.g. because of graph breaks). Testing this would require
bigger models and more data, which is prohibitively slow to test.

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Benjamin Bossan
2024-05-17 18:03:27 +02:00
committed by GitHub
parent 3f7aacd601
commit 4e32679f37
6 changed files with 648 additions and 25 deletions

View File

@ -40,4 +40,4 @@ jobs:
run: |
echo "PEFT_DEBUG_WITH_TORCH_COMPILE=$PEFT_DEBUG_WITH_TORCH_COMPILE"
git status
make test
make tests_torch_compile

View File

@ -53,3 +53,6 @@ transformers_tests:
tests_regression:
python -m pytest -s --regression tests/regression/ $(if $(IS_GITHUB_CI),--report-log "regression_tests.log",)
tests_torch_compile:
python -m pytest tests/test_torch_compile.py $(if $(IS_GITHUB_CI),--report-log "compile_tests.log",)

View File

@ -37,6 +37,8 @@
title: Adapter injection
- local: developer_guides/mixed_models
title: Mixed adapter types
- local: developer_guides/torch_compile
title: torch.compile
- local: developer_guides/contributing
title: Contribute to PEFT
- local: developer_guides/troubleshooting

View File

@ -0,0 +1,75 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# torch.compile
In PEFT, [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) works for some but not all features. The reason why it won't always work is because PEFT is highly dynamic in certain places (loading and switching between multiple adapters, for instance), which can cause trouble for `torch.compile`. In other places, `torch.compile` may work, but won't be as fast as expected because of graph breaks.
If you don't see an error, it doesn't necessarily mean that `torch.compile` worked correctly. It might give you an output, but the output is incorrect. This guide describes what works with `torch.compile` and what doesn't.
> [!TIP]
> Unless indicated otherwise, the default `torch.compile` settings were used.
## Training and inference with `torch.compile`
These features **work** with `torch.compile`. Everything listed below was tested with a causal LM:
- Training with `Trainer` from 🤗 transformers
- Training with a custom PyTorch loop
- Inference
- Generation
The following adapters were tested successfully:
- AdaLoRA
- BOFT
- IA³
- Layer Norm Tuning
- LoHa
- LoRA
- LoRA + DoRA
- OFT
- VeRA
The following adapters **don't work** correctly for training or inference when using `torch.compile`:
- LoKr
- LoRA targeting embedding layers
## Advanced PEFT features with `torch.compile`
Below are some of the more advanced PEFT features that **work**. They were all tested with LoRA.
- `modules_to_save` (i.e. `config = LoraConfig(..., modules_to_save=...)`)
- Merging adapters (one or multiple)
- Merging multiple adapters into one adapter (i.e. calling `model.add_weighted_adapter(...)`)
Generally, we can expect that if a feature works correctly with LoRA and is also supported by other adapter types, it should also work for that adapter type.
The more advanced PEFT features below **don't work** in conjunction with `torch.compile`. Tests were run with LoRA:
- Using PEFT adapters with quantization (bitsandbytes)
- Inference with multiple adapters
- Unloading (i.e. calling `model.merge_and_unload()`)
- Disabling adapters (i.e. using `with model.disable_adapter()`)
- Mixed adapter batches (i.e. calling `model(batch, adapter_names=["__base__", "default", "other", ...])`)
## Test cases
All the use cases listed above are tested inside of [`peft/tests/test_torch_compile.py`](https://github.com/huggingface/peft/blob/main/tests/test_torch_compile.py). If you want to check in more detail how we tested a certain feature, please go to that file and check the test that corresponds to your use case.
> [!TIP]
> If you have another use case where you know that `torch.compile` does or does not work with PEFT, please contribute by letting us know or by opening a PR to add this use case to the covered test cases.

View File

@ -1,24 +0,0 @@
import os
if os.environ.get("PEFT_DEBUG_WITH_TORCH_COMPILE") == "1":
# This is a hack purely for debugging purposes. If the environment variable PEFT_DEBUG_WITH_TORCH_COMPILE is set to
# 1, get_peft_model() will return a compiled model. This way, all unit tests that use peft.get_peft_model() will
# use a compiled model. See .github/workflows/torch_compile_tests.yml.
import torch
import peft
from peft.mapping import get_peft_model as get_peft_model_original
# TODO: Experimental dynamo feature that should allow correct compilation of more PEFT modules. This should be
# removed once PyTorch has found a better solution, as this incurs a performance penalty.
# https://github.com/pytorch/pytorch/issues/124717#issuecomment-2083235776
torch._dynamo.config.guard_nn_modules = True
def get_peft_model_new(*args, **kwargs):
"""Make get_peft_model() return a compiled model."""
peft_model = get_peft_model_original(*args, **kwargs)
peft_model = torch.compile(peft_model)
return peft_model
peft.get_peft_model = get_peft_model_new

567
tests/test_torch_compile.py Normal file
View File

@ -0,0 +1,567 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The intent of the tests contained in this file is to check as many PEFT features as possible with torch.compile. This
# is thus a document on how well torch.compile is supported by PEFT. Currently, we know that certain features do not
# work with torch.compile. The corresponding tests should be marked with `@pytest.mark.xfail(strict=True)`.
#
# When adding a new test that fails with torch.compile, please make sure first that it does NOT fail without
# torch.compile.
import gc
import os
import pytest
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from peft import (
AdaLoraConfig,
BOFTConfig,
IA3Config,
LNTuningConfig,
LoHaConfig,
LoKrConfig,
LoraConfig,
OFTConfig,
PeftModel,
TaskType,
VeraConfig,
get_peft_model,
)
# only run (very slow) torch.compile tests when explicitly asked to
if os.environ.get("PEFT_DEBUG_WITH_TORCH_COMPILE") != "1":
pytest.skip(allow_module_level=True)
# Mapping: name of the setting -> (Peft config instance, torch.compile kwargs)
SETTINGS = {
"adalora": (AdaLoraConfig(task_type=TaskType.CAUSAL_LM), {}),
"boft": (BOFTConfig(task_type=TaskType.CAUSAL_LM), {}),
"dora": (LoraConfig(task_type=TaskType.CAUSAL_LM, use_dora=True), {}),
"ia3": (IA3Config(task_type=TaskType.CAUSAL_LM), {}),
"ln_tuning": (LNTuningConfig(task_type=TaskType.CAUSAL_LM, target_modules=["final_layer_norm"]), {}),
"loha": (LoHaConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"]), {}),
"lokr": pytest.param(
(LoKrConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"]), {}),
marks=pytest.mark.xfail(strict=True),
),
"lora": (LoraConfig(task_type=TaskType.CAUSAL_LM), {}),
"lora-target-embeddings": pytest.param(
(LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["embed_tokens"]), {}),
marks=pytest.mark.xfail(strict=True),
),
"lora-with-modules-to-save": (LoraConfig(task_type=TaskType.CAUSAL_LM, modules_to_save=["embed_tokens"]), {}),
"oft": (OFTConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"]), {}),
"vera": (VeraConfig(task_type=TaskType.CAUSAL_LM), {}),
}
@pytest.mark.single_gpu_tests
class TestTorchCompileCausalLM:
"""
Tests for using torch.compile with causal LM.
Tip: When adding a new test, set `fake_compile = False` below. With this setting, torch.compile is being skipped.
This is useful for two reasons:
- compile is slow, so to quickly iterate on the test, it's best to disable it and only enable it at the very end
- even if you expect the test to fail with compile, as compile does not work with every PEFT feature, it still MUST
succeed without compile, otherwise the test is incorrect.
Before creating the PR, disable `fake_compile`.
"""
fake_compile = False
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
max_train_loss = 15.0 # generous threshold for maximum loss after training
@pytest.fixture(autouse=True)
def teardown(self):
r"""
Efficient mechanism to free GPU memory after each test. Based on
https://github.com/huggingface/transformers/issues/21094
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
@pytest.fixture(scope="class")
def tokenizer(self):
return AutoTokenizer.from_pretrained(self.model_id)
@pytest.fixture(scope="class")
def data(self, tokenizer):
def tokenize(samples):
# For some reason, the max sequence length is not honored by the tokenizer, resulting in IndexErrors. Thus,
# manually ensure that sequences are not too long.
tokenized = tokenizer(samples["quote"])
tokenized["input_ids"] = [input_ids[: tokenizer.model_max_length] for input_ids in tokenized["input_ids"]]
tokenized["attention_mask"] = [
input_ids[: tokenizer.model_max_length] for input_ids in tokenized["attention_mask"]
]
return tokenized
data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(tokenize, batched=True)
# We need to manually remove unused columns. This is because we cannot use remove_unused_columns=True in the
# Trainer, as this leads to errors with torch.compile. We also cannot just leave them in, as they contain
# strings. Therefore, manually remove all unused columns.
data = data.remove_columns(["quote", "author", "tags"])
return data
def compile(self, model, compile_kwargs):
compile_kwargs = compile_kwargs.copy()
# those are only for the Trainer arguments
compile_kwargs.pop("torch_compile_backend", None)
compile_kwargs.pop("torch_compile_mode", None)
if self.fake_compile:
return model
return torch.compile(model, **compile_kwargs)
@pytest.mark.parametrize("settings", SETTINGS.values(), ids=SETTINGS.keys())
def test_causal_lm_training_trainer_compile(self, settings, tokenizer, data, tmp_path):
r"""Train a PEFT model with torch.compile using Trainer"""
tmp_dir = tmp_path / "model"
config, compile_kwargs = settings
if isinstance(config, AdaLoraConfig):
pytest.skip(reason="AdaLora does not work correctly with Trainer")
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
)
model = get_peft_model(model, config)
# record outputs before training
model.eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_before = model(sample)
model.train()
train_kwargs = {
"per_device_train_batch_size": 4,
"max_steps": 5,
"learning_rate": 1e-3,
"logging_steps": 1,
"output_dir": tmp_dir,
"seed": 0,
}
training_args = TrainingArguments(
torch_compile=not self.fake_compile,
torch_compile_backend=compile_kwargs.get("torch_compile_backend", None),
torch_compile_mode=compile_kwargs.get("torch_compile_mode", None),
**train_kwargs,
)
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=training_args,
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.eval()
atol, rtol = 1e-4, 1e-4
with torch.inference_mode():
output_after = model(sample)
tokens_after = model.generate(sample)
assert torch.isfinite(output_after.logits).all()
# sanity check: model was updated
assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
assert trainer.state.log_history[-1]["train_loss"] < self.max_train_loss
# check saving the model and loading it without compile
model.save_pretrained(tmp_path)
del model
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto")
model = PeftModel.from_pretrained(model, tmp_path)
with torch.inference_mode():
output_loaded = model(sample)
tokens_loaded = model.generate(sample)
assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
assert (tokens_after == tokens_loaded).all()
@pytest.mark.parametrize("settings", SETTINGS.values(), ids=SETTINGS.keys())
def test_causal_lm_training_pytorch_compile(self, settings, tokenizer, data, tmp_path):
r"""Train a PEFT model with torch.compile using PyTorch training loop"""
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
)
config, compile_kwargs = settings
model = get_peft_model(model, config)
if isinstance(config, AdaLoraConfig):
model.base_model.peft_config["default"].total_step = 5
model = self.compile(model, compile_kwargs)
# record outputs before training
model.eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_before = model(sample)
model.train()
model.config.use_cache = False
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size = 4
losses = []
max_steps = 5 * batch_size
for i in range(0, max_steps, batch_size):
batch = tokenizer.pad(data["train"][i : i + batch_size], return_tensors="pt").to(model.device)
# add targets
batch["labels"] = batch["input_ids"].clone()
optimizer.zero_grad()
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
losses.append(loss.item())
if isinstance(config, AdaLoraConfig):
model.base_model.update_and_allocate(i)
model.eval()
with torch.inference_mode():
output_after = model(sample)
tokens_after = model.generate(sample)
assert torch.isfinite(output_after.logits).all()
atol, rtol = 1e-4, 1e-4
# sanity check: model was updated
assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
assert losses[-1] < self.max_train_loss
# check saving the model and loading it without compile
model.save_pretrained(tmp_path)
del model
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto")
model = PeftModel.from_pretrained(model, tmp_path)
with torch.inference_mode():
output_loaded = model(sample)
tokens_loaded = model.generate(sample)
assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
assert (tokens_after == tokens_loaded).all()
@pytest.mark.xfail(strict=True)
def test_causal_lm_training_lora_bnb_compile(self, tokenizer, data, tmp_path):
r"""Train a bnb quantized LoRA model with torch.compile using PyTorch training loop"""
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
config = LoraConfig(task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, config)
model = self.compile(model, {})
# record outputs before training
model.eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_before = model(sample)
model.train()
model.config.use_cache = False
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size = 4
losses = []
max_steps = 5 * batch_size
for i in range(0, max_steps, batch_size):
batch = tokenizer.pad(data["train"][i : i + batch_size], return_tensors="pt").to(model.device)
# add targets
batch["labels"] = batch["input_ids"].clone()
optimizer.zero_grad()
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
losses.append(loss.item())
model.eval()
with torch.inference_mode():
output_after = model(sample)
assert torch.isfinite(output_after.logits).all()
atol, rtol = 1e-4, 1e-4
# sanity check: model was updated
assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
assert losses[-1] < self.max_train_loss
# check saving the model and loading it without compile
model.save_pretrained(tmp_path)
del model
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id, device_map="auto", quantization_config=BitsAndBytesConfig(load_in_4bit=True)
)
model = PeftModel.from_pretrained(model, tmp_path)
with torch.inference_mode():
# after loading, outputs are float32 for some reason
output_loaded = model(sample)
assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
@pytest.mark.xfail(strict=True)
def test_causal_lm_multiple_lora_adapter_compile(self, tokenizer, data):
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
).eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_base = model(sample)
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
model = self.compile(model, {})
model.add_adapter("other", config)
model = self.compile(model, {})
with torch.inference_mode():
output_default_adapter = model(sample)
model.set_adapter("other")
with torch.inference_mode():
output_other_adapter = model(sample)
atol, rtol = 1e-4, 1e-4
# outputs of the base model != output of default adapter != output of other adapter
assert not torch.allclose(output_base.logits, output_default_adapter.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_base.logits, output_other_adapter.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_default_adapter.logits, output_other_adapter.logits, atol=atol, rtol=rtol)
# now delete the other adapter
model.delete_adapter("other")
model.set_adapter("default")
with torch.inference_mode():
output_after_delete = model(sample)
# outputs after delete == output of default adapter
assert torch.allclose(output_default_adapter.logits, output_after_delete.logits, atol=atol, rtol=rtol)
@pytest.mark.xfail(strict=True)
def test_causal_lm_disable_lora_adapter_compile(self, tokenizer, data):
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
).eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_base = model(sample)
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
model = self.compile(model, {})
output_lora = model(sample)
with model.disable_adapter():
with torch.inference_mode():
output_disabled = model(sample)
atol, rtol = 1e-4, 1e-4
# outputs of the base model == output disabled adapter != output of lora adapter
assert torch.allclose(output_base.logits, output_disabled.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_base.logits, output_lora.logits, atol=atol, rtol=rtol)
def test_causal_lm_merging_lora_adapter_compile(self, tokenizer, data):
# merge the adapter
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
).eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_base = model(sample)
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
with torch.inference_mode():
output_lora = model(sample)
model.merge_adapter()
with torch.inference_mode():
output_merged = model(sample)
# merging is less precise, be more tolerant
atol, rtol = 1e-1, 1e-1
# outputs of the base model != output of lora adapter == output of merged adapter
assert not torch.allclose(output_base.logits, output_lora.logits, atol=atol, rtol=rtol)
assert torch.allclose(output_lora.logits, output_merged.logits, atol=atol, rtol=rtol)
def test_causal_lm_merging_multiple_lora_adapters_compile(self, tokenizer, data):
# merge multiple adapters at once
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
).eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_base = model(sample)
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
model.add_adapter("other", config)
with torch.inference_mode():
output_default = model(sample)
model.set_adapter("other")
with torch.inference_mode():
output_other = model(sample)
model.base_model.merge_adapter(["default", "other"])
with torch.inference_mode():
output_merged = model(sample)
# merging is less precise, be more tolerant
atol, rtol = 1e-1, 1e-1
# outputs of the base model != output of default adapter != output of other adapter
assert not torch.allclose(output_base.logits, output_default.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_base.logits, output_other.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_default.logits, output_other.logits, atol=atol, rtol=rtol)
# outputs of merged adapter != all others
assert not torch.allclose(output_base.logits, output_merged.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_default.logits, output_merged.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_other.logits, output_merged.logits, atol=atol, rtol=rtol)
@pytest.mark.xfail(strict=True)
def test_causal_lm_merge_and_unload_lora_adapter_compile(self, tokenizer, data):
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
).eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_base = model(sample)
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
model = self.compile(model, {})
with torch.inference_mode():
output_lora = model(sample)
unloaded = model.merge_and_unload()
with torch.inference_mode():
output_unloaded = unloaded(sample)
# merging is less precise, be more tolerant
atol, rtol = 1e-1, 1e-1
# outputs of the base model != output of lora adapter == output of unloaded adapter
assert not torch.allclose(output_base.logits, output_lora.logits, atol=atol, rtol=rtol)
assert torch.allclose(output_lora.logits, output_unloaded.logits, atol=atol, rtol=rtol)
@pytest.mark.xfail(strict=True)
def test_causal_lm_mixed_batch_lora_adapter_compile(self, tokenizer, data):
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
).eval()
# we need at least 3 samples for this to work!
sample = {
"input_ids": torch.arange(12).reshape(3, 4).to("cuda"),
"attention_mask": torch.ones(3, 4).long().to("cuda"),
}
with torch.inference_mode():
output_base = model(**sample)
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
with torch.inference_mode():
output_default = model(**sample)
model.add_adapter("other", config)
model.set_adapter("other")
with torch.inference_mode():
output_other = model(**sample)
model = self.compile(model, {})
# set adapter_indices so that it alternates between 0 (base), lora 1, and lora 2
adapter_names = ["__base__", "default", "other"]
with torch.inference_mode():
output_mixed = model(**sample, adapter_names=adapter_names)
atol, rtol = 1e-4, 1e-4
# outputs of the base model != output of lora adapter 1 != output of other adapter
assert not torch.allclose(output_base.logits, output_default.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_default.logits, output_other.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_other.logits, output_mixed.logits, atol=atol, rtol=rtol)
# outputs of mixed adapter is mix of all 3
assert torch.allclose(output_base.logits[0], output_mixed.logits[0], atol=atol, rtol=rtol)
assert torch.allclose(output_default.logits[1], output_mixed.logits[1], atol=atol, rtol=rtol)
assert torch.allclose(output_other.logits[2], output_mixed.logits[2], atol=atol, rtol=rtol)
def test_causal_lm_add_weighted_adapter_lora_adapter_compile(self, tokenizer, data):
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
).eval()
sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
with torch.inference_mode():
output_base = model(sample)
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
model.add_adapter("other", config)
with torch.inference_mode():
output_default = model(sample)
model.set_adapter("other")
with torch.inference_mode():
output_other = model(sample)
model.add_weighted_adapter(["default", "other"], [0.5, 0.5], adapter_name="combined")
model.set_adapter("combined")
with torch.inference_mode():
output_combined = model(sample)
atol, rtol = 1e-4, 1e-4
# outputs of the base model != output of default adapter != output of other adapter
assert not torch.allclose(output_base.logits, output_default.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_base.logits, output_other.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_default.logits, output_other.logits, atol=atol, rtol=rtol)
# outputs of combined adapter != all others
assert not torch.allclose(output_base.logits, output_combined.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_default.logits, output_combined.logits, atol=atol, rtol=rtol)
assert not torch.allclose(output_other.logits, output_combined.logits, atol=atol, rtol=rtol)