mirror of
https://github.com/huggingface/peft.git

Compare commits: 87b90f045e...smangrul/f (11 commits)

@ -16,13 +16,15 @@ Below is a table that summarizes the compatibility between PEFT's LoRA, [`bitsan

| Zero-1 | 🟢 |
| Zero-2 | 🟢 |
- | Zero-3 | 🔴 |
+ | Zero-3 | 🟢 |

+ For DeepSpeed Stage 3 + QLoRA, please refer to the section [Use PEFT QLoRA and DeepSpeed with ZeRO3 for finetuning large models on multiple GPUs](#use-peft-qlora-and-deepspeed-with-zero3-for-finetuning-large-models-on-multiple-gpus) below.

For confirming these observations, we ran the SFT (Supervised Fine-tuning) [official example scripts](https://github.com/huggingface/trl/tree/main/examples) of the [Transformers Reinforcement Learning (TRL) library](https://github.com/huggingface/trl) using QLoRA + PEFT and the accelerate configs available [here](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs). We ran these experiments on 2x NVIDIA T4 GPUs.

- Note DeepSpeed-Zero3 and `bitsandbytes` are currently **not** compatible.

- # Use PEFT and DeepSpeed with ZeRO3 for finetuning large models on multiple machines and multiple nodes
+ # Use PEFT and DeepSpeed with ZeRO3 for finetuning large models on multiple devices and multiple nodes

This section of the guide will help you learn how to use our DeepSpeed [training script](https://github.com/huggingface/peft/blob/main/examples/sft/train.py) for performing SFT. You'll configure the script to do SFT (supervised fine-tuning) of the Llama-70B model with LoRA and ZeRO-3 on 8x H100 80GB GPUs on a single machine. You can configure it to scale to multiple machines by changing the accelerate config.

@ -171,6 +173,115 @@ In the above example, the memory consumed per GPU is 64 GB (80%) as seen in the

## More resources

You can also refer to this blog post [Falcon 180B Finetuning using 🤗 PEFT and DeepSpeed](https://medium.com/@sourabmangrulkar/falcon-180b-finetuning-using-peft-and-deepspeed-b92643091d99) on how to finetune the 180B Falcon model on 16 A100 GPUs on 2 machines.

# Use PEFT QLoRA and DeepSpeed with ZeRO3 for finetuning large models on multiple GPUs

In this section, we will look at how to use QLoRA and DeepSpeed Stage-3 for finetuning the 70B Llama model on 2x 40GB GPUs.

For this, we first need `bitsandbytes>=0.43.0`, `accelerate>=0.28.0`, `transformers>4.38.2`, `trl>0.7.11` and `peft>0.9.0`. We need to set `zero3_init_flag` to true in the Accelerate config. Below is the config, which can be found at [deepspeed_config_z3_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/deepspeed_config_z3_qlora.yaml):

```yml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

The launch command is given below; it is available at [run_peft_qlora_deepspeed_stage3.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_deepspeed_stage3.sh):

```
accelerate launch --config_file "configs/deepspeed_config_z3_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-dsz3" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
```

Notice the new argument being passed, `bnb_4bit_quant_storage_dtype`, which denotes the data type for packing the 4-bit parameters. For example, when it is set to `bfloat16`, **16/4 = 4** 4-bit params are packed together post quantization.
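
To make the packing arithmetic concrete, here is a minimal, self-contained sketch (not part of the example script; the config values simply mirror the launch command above) that builds the corresponding `BitsAndBytesConfig` and computes how many 4-bit weights fit into one storage element:

```python
import torch
from transformers import BitsAndBytesConfig

# Mirrors --bnb_4bit_quant_storage_dtype "bfloat16" from the launch command above.
storage_dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=storage_dtype,
)

# bfloat16 is 16 bits wide, so 16 / 4 = 4 quantized weights are packed per element;
# a float32 storage dtype would pack 32 / 4 = 8 weights per element instead.
bits_per_element = torch.tensor([], dtype=storage_dtype).element_size() * 8
print(f"{bits_per_element // 4} four-bit params per {storage_dtype} element")
```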

In terms of training code, the important code changes are:

```diff
...

bnb_config = BitsAndBytesConfig(
    load_in_4bit=args.use_4bit_quantization,
    bnb_4bit_quant_type=args.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=args.use_nested_quant,
+   bnb_4bit_quant_storage=quant_storage_dtype,
)

...

model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    quantization_config=bnb_config,
    trust_remote_code=True,
    attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
+   torch_dtype=quant_storage_dtype or torch.float32,
)
```

Notice that `torch_dtype` for `AutoModelForCausalLM` is the same as the `bnb_4bit_quant_storage` data type. That's it. Everything else is handled by Trainer and TRL.

## Memory usage

In the above example, the memory consumed per GPU is **36.6 GB**. Therefore, what took 8x 80GB GPUs with DeepSpeed Stage 3 + LoRA, and a couple of 80GB GPUs with DDP + QLoRA, now requires 2x 40GB GPUs. This makes finetuning of large models more accessible.

# Use PEFT and DeepSpeed with ZeRO3 and CPU Offloading for finetuning large models on a single GPU

This section of the guide will help you learn how to use our DeepSpeed [training script](https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq_accelerate_ds_zero3_offload.py). You'll configure the script to train a large model for conditional generation with ZeRO-3 and CPU Offload.

@ -335,3 +446,4 @@ dataset['train'][label_column][:10]=['no complaint', 'no complaint', 'complaint'

# Caveats

1. Merging when using PEFT and DeepSpeed is currently unsupported and will raise an error.
2. When using CPU offloading, the major gains from using PEFT to shrink the optimizer states and gradients to that of the adapter weights would be realized on CPU RAM, and there won't be savings with respect to GPU memory.
3. DeepSpeed Stage 3 and QLoRA, when used with CPU offloading, lead to more GPU memory usage compared to disabling CPU offloading.

@ -169,6 +169,117 @@ In the above example, the memory consumed per GPU is 72-80 GB (90-98%) as seen

</div>
<small>GPU memory usage for the training run</small>

# Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs

In this section, we will look at how to use QLoRA and FSDP for finetuning the 70B Llama model on 2x 24GB GPUs. [Answer.AI](https://www.answer.ai/) in collaboration with bitsandbytes and Hugging Face 🤗 open-sourced code enabling the usage of FSDP+QLoRA and explained the whole process in their insightful blogpost [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html). This is now integrated in the Hugging Face ecosystem.

For this, we first need `bitsandbytes>=0.43.0`, `accelerate>=0.28.0`, `transformers>4.38.2`, `trl>0.7.11` and `peft>0.9.0`. We need to set `fsdp_cpu_ram_efficient_loading=true`, `fsdp_use_orig_params=false` and `fsdp_offload_params=true` (CPU offloading) in the Accelerate config. When not using the accelerate launcher, you can alternatively set the environment variable `export FSDP_CPU_RAM_EFFICIENT_LOADING=true`. Here, we will use the Accelerate config below, which can be found at [fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml):

```yml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

The launch command is given below; it is available at [run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh):

```
accelerate launch --config_file "configs/fsdp_config_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-fsdp" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
```

Notice the new argument being passed, `bnb_4bit_quant_storage_dtype`, which denotes the data type for packing the 4-bit parameters. For example, when it is set to `bfloat16`, **16/4 = 4** 4-bit params are packed together post quantization. When using mixed precision training with `bfloat16`, `bnb_4bit_quant_storage_dtype` can be either `bfloat16` for pure `bfloat16` finetuning, or `float32` for automatic mixed precision (this consumes more GPU memory). When using mixed precision training with `float16`, `bnb_4bit_quant_storage_dtype` should be set to `float32` for stable automatic mixed precision training.
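
The paragraph above boils down to a small decision rule. Below is a hedged sketch of that rule; the helper name and arguments are illustrative and not taken from the example script:

```python
import torch

def pick_quant_storage_dtype(mixed_precision: str, pure_bf16: bool = True) -> torch.dtype:
    """Illustrative helper encoding the guidance above; not part of the example script."""
    if mixed_precision == "bf16":
        # bfloat16 storage -> pure bf16 finetuning; float32 storage -> bf16 AMP (more GPU memory).
        return torch.bfloat16 if pure_bf16 else torch.float32
    if mixed_precision == "fp16":
        # float16 AMP needs float32 storage for numerically stable training.
        return torch.float32
    return torch.float32

print(pick_quant_storage_dtype("bf16"))  # torch.bfloat16
print(pick_quant_storage_dtype("fp16"))  # torch.float32
```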

In terms of training code, the important code changes are:

```diff
...

bnb_config = BitsAndBytesConfig(
    load_in_4bit=args.use_4bit_quantization,
    bnb_4bit_quant_type=args.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=args.use_nested_quant,
+   bnb_4bit_quant_storage=quant_storage_dtype,
)

...

model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    quantization_config=bnb_config,
    trust_remote_code=True,
    attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
+   torch_dtype=quant_storage_dtype or torch.float32,
)
```

Notice that `torch_dtype` for `AutoModelForCausalLM` is the same as the `bnb_4bit_quant_storage` data type. That's it. Everything else is handled by Trainer and TRL.
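
This matters for FSDP because the flat parameters it builds must share a single dtype. As a hedged sanity check (the helper below is illustrative, not part of the example script), with a `bfloat16` storage dtype every parameter, including the packed 4-bit ones, should report the same dtype:

```python
import torch
from torch import nn

def parameter_dtypes(model: nn.Module) -> set:
    """Return the set of parameter dtypes; for FSDP+QLoRA this should ideally be a single
    entry (e.g. {torch.bfloat16}) so FSDP can flatten all parameters uniformly."""
    return {param.dtype for param in model.parameters()}

# Usage, assuming `model` was loaded as in the diff above:
# assert parameter_dtypes(model) == {torch.bfloat16}
```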

## Memory usage

In the above example, the memory consumed per GPU is **19.6 GB** while CPU RAM usage is around **107 GB**. When disabling CPU offloading, the GPU memory usage is **35.6 GB per GPU**. Therefore, what took 16x 80GB GPUs for full finetuning, 8x 80GB GPUs with FSDP + LoRA, and a couple of 80GB GPUs with DDP + QLoRA, now requires 2x 24GB GPUs. This makes finetuning of large models more accessible.

## More resources

You can also refer to the [llama-recipes](https://github.com/facebookresearch/llama-recipes/?tab=readme-ov-file#fine-tuning) repo and the [Getting started with Llama](https://llama.meta.com/get-started/#fine-tuning) guide on how to finetune using FSDP and PEFT.

@ -176,4 +287,5 @@ You can also refer the [llama-recipes](https://github.com/facebookresearch/llama

## Caveats

1. Merging when using PEFT and FSDP is currently unsupported and will raise an error.
2. Passing the `modules_to_save` config parameter is untested at present.
3. GPU memory saving when using CPU offloading is untested at present.
4. When using FSDP+QLoRA, `paged_adamw_8bit` currently results in an error when saving a checkpoint.

@ -23,10 +23,10 @@ Note:

1. At present, `use_reentrant` needs to be `False` when using gradient checkpointing with Multi-GPU QLoRA or it will lead to errors. However, this leads to huge GPU memory consumption. A sketch of where this flag ends up is shown below.
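
A hedged sketch of how such a flag is typically forwarded to the 🤗 Trainer; this is not the example script itself, and the argument values simply mirror the launch commands above:

```python
from transformers import TrainingArguments

# Gradient checkpointing with the non-reentrant implementation, as recommended above
# for multi-GPU QLoRA. Adjust output_dir and precision flags for your own setup.
training_args = TrainingArguments(
    output_dir="llama-sft-qlora-fsdp",
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```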

## Multi-GPU SFT with LoRA and DeepSpeed

- When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. TO use LoRA with DeepSpeed, refer the docs at [PEFT with DeepSpeed](https://huggingface.co/docs/peft/accelerate/deepspeed).
+ When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. To use LoRA with DeepSpeed, refer to the docs at [PEFT with DeepSpeed](https://huggingface.co/docs/peft/accelerate/deepspeed).

## Multi-GPU SFT with LoRA and FSDP

- When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. TO use LoRA with DeepSpeed, refer the docs at [PEFT with FSDP](https://huggingface.co/docs/peft/accelerate/fsdp).
+ When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. To use LoRA with FSDP, refer to the docs at [PEFT with FSDP](https://huggingface.co/docs/peft/accelerate/fsdp).

examples/sft/configs/deepspeed_config_z3_qlora.yaml (new file, 22 lines)
@ -0,0 +1,22 @@

```yml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

examples/sft/configs/fsdp_config_qlora.yaml (new file, 25 lines)
@ -0,0 +1,25 @@

```yml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

examples/sft/run_peft_qlora_deepspeed_stage3.sh (new file, 42 lines)
@ -0,0 +1,42 @@

```
accelerate launch --config_file "configs/deepspeed_config_z3_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-dsz3" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
```

examples/sft/run_peft_qlora_fsdp.sh (new file, 42 lines)
@ -0,0 +1,42 @@

```
accelerate launch --config_file "configs/fsdp_config_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-fsdp" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
```

@ -39,6 +39,10 @@ class ModelArguments:

```python
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="float32",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
```

@ -133,14 +137,7 @@ def main(model_args, data_args, training_args):

```python
        max_seq_length=data_args.max_seq_length,
    )
    trainer.accelerator.print(f"{trainer.model}")
    if model_args.use_peft_lora:
        # handle PEFT+FSDP case
        trainer.model.print_trainable_parameters()
        if getattr(trainer.accelerator.state, "fsdp_plugin", None):
            from peft.utils.other import fsdp_auto_wrap_policy

            fsdp_plugin = trainer.accelerator.state.fsdp_plugin
            fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)
        trainer.model.print_trainable_parameters()

    # train
    checkpoint = None
```

@ -84,8 +84,8 @@ def create_datasets(tokenizer, data_args, training_args, apply_chat_template=Fal

```python
def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        from unsloth import FastLanguageModel
    device_map = None
    bnb_config = None
    quant_storage_dtype = None

    if (
        torch.distributed.is_available()
```

@ -97,12 +97,14 @@ def create_and_prepare_model(args, data_args, training_args):

```python
    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )

        if compute_dtype == torch.float16 and args.use_4bit_quantization:
```

@ -114,13 +116,6 @@ def create_and_prepare_model(args, data_args, training_args):

```python
    elif args.use_8bit_quantization:
        bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)

    if args.use_4bit_quantization or args.use_8bit_quantization:
        device_map = (
            int(os.environ.get("LOCAL_RANK", -1))
            if torch.distributed.is_available() and torch.distributed.is_initialized()
            else "auto"
        )  # {"": 0}

    if args.use_unsloth:
        # Load model
        model, _ = FastLanguageModel.from_pretrained(
```

@ -133,9 +128,9 @@ def create_and_prepare_model(args, data_args, training_args):

```python
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            quantization_config=bnb_config,
            device_map=device_map,
            trust_remote_code=True,
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
            torch_dtype=quant_storage_dtype or torch.float32,
        )

    peft_config = None
```

@ -505,7 +505,8 @@ class PeftModel(PushToHubMixin, torch.nn.Module):

```diff
             # one needs to multiply the number of parameters by 2 to get
             # the correct number of parameters
             if param.__class__.__name__ == "Params4bit":
-                num_params = num_params * 2
+                num_bytes = param.quant_storage.itemsize if hasattr(param, "quant_storage") else 1
+                num_params = num_params * 2 * num_bytes

             all_param += num_params
             if param.requires_grad:
```
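
To see why the count is multiplied by `2 * num_bytes`, here is a small, hedged arithmetic sketch (the layer shape is just an example): a 4-bit weight matrix is stored packed into wider storage elements, so the `Params4bit` tensor holds fewer elements than real weights.

```python
# Example: a 1024x1024 linear weight quantized to 4 bits and packed into bfloat16 storage.
true_params = 1024 * 1024                            # real number of weights
num_bytes = 2                                        # itemsize of the bfloat16 quant storage dtype
packed_numel = true_params * 4 // (num_bytes * 8)    # elements actually held by the Params4bit tensor
assert packed_numel * 2 * num_bytes == true_params   # the rule above recovers the true count
```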

@ -103,7 +103,9 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad

```diff
     if not is_gptq_quantized and not is_aqlm_quantized:
         # cast all non INT8 parameters to fp32
         for param in model.parameters():
-            if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+            if (
+                (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
+            ) and param.__class__.__name__ != "Params4bit":
                 param.data = param.data.to(torch.float32)

     if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized) and use_gradient_checkpointing:
```