Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 10:03:51 +08:00

Add Qwen3-VL notebooks (SFT, GRPO) (#4275)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
commit 49c8f14b06 (parent cefbacb30e), committed via GitHub

694  examples/notebooks/grpo_qwen3_vl.ipynb  (new file)
@@ -0,0 +1,694 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-J8iGzLf4rUJ"
|
||||
},
|
||||
"source": [
|
||||
"# GRPO Qwen3-VL with QLoRA using TRL\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/huggingface/trl/blob/main/notebooks/grpo_qwen3_vl.ipynb)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can fine-tune cutting edge vision language models. It comes with support for quantized parameter efficient fine-tuning technique **QLoRA**, so we can use free Colab (T4 GPU) to fine-tune models like [Qwen3-VL](https://huggingface.co/collections/Qwen/qwen3-vl-68d2a7c1b8a8afce4ebd2dbe).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
|
||||
"- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n",
|
||||
"- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n",
|
||||
"- [More Qwen3-VL Fine-tuning Examples (including TRL scripts)](https://github.com/QwenLM/Qwen3-VL/tree/main/qwen-vl-finetune/)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "NvrzGRnu48Vz"
|
||||
},
|
||||
"source": [
|
||||
"## Install dependencies\n",
|
||||
"\n",
|
||||
"We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "8CfZlUevmkg7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -Uq \"trl[peft]\" bitsandbytes trackio math_verify"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gpzI6omi7728"
|
||||
},
|
||||
"source": [
|
||||
"### Log in to Hugging Face\n",
|
||||
"\n",
|
||||
"Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "4Ncx0wYtnYCW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "V_Zylc4t79-n"
|
||||
},
|
||||
"source": [
|
||||
"## Load dataset\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"We'll load the [**lmms-lab/multimodal-open-r1-8k-verified**](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset from the Hugging Face Hub using the `datasets` library.\n",
|
||||
"\n",
|
||||
"This dataset contains maths problems with the image representing the problem, along with the solution in thinking format specially tailored for VLMs. By training our model with this dataset, it'll improve its maths and thinking reasoning.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "TzXogU24F_QR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
|
||||
"train_dataset = load_dataset(dataset_id, split='train[:5%]')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gVV7RoRN8zk5"
|
||||
},
|
||||
"source": [
|
||||
"In addition to the `problem` and `image` columns, we also include a custom system prompt to tell the model how we'd like the generation.\n",
|
||||
"\n",
|
||||
"The system prompt is extracted from DeepSeek R1. Refer to [this previous recipe](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) for more details.\n",
|
||||
"\n",
|
||||
"We convert the dataset samples into conversation samples, including the system prompt and one image and problem description per sample, since this is how the GRPO trainer expects them.\n",
|
||||
"\n",
|
||||
"We also set `padding_side=\"left\"` to ensure that generated completions during training are concatenated directly after the prompt, which is essential for GRPO to correctly compare token-level probabilities between preferred and rejected responses."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZT1JfiiTGExB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoProcessor\n",
|
||||
"\n",
|
||||
"model_name = \"Qwen/Qwen3-VL-4B-Instruct\" # \"Qwen/Qwen3-VL-8B-Instruct\"\n",
|
||||
"processor = AutoProcessor.from_pretrained(model_name, padding_side=\"left\")\n",
|
||||
"\n",
|
||||
"SYSTEM_PROMPT = (\n",
|
||||
" \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. \"\n",
|
||||
" \"You first think about the reasoning process as an internal monologue and then provide the user with the answer. \"\n",
|
||||
" \"Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_conversation(example):\n",
|
||||
" conversation = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" \"content\": [{\"type\": \"text\", \"text\": SYSTEM_PROMPT}],\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": example[\"image\"]},\n",
|
||||
" {\"type\": \"text\", \"text\": example[\"problem\"]},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
" prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)\n",
|
||||
" return {\n",
|
||||
" \"prompt\": prompt,\n",
|
||||
" \"image\": example[\"image\"],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"train_dataset = train_dataset.map(make_conversation)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "5txAuMAa8ock"
|
||||
},
|
||||
"source": [
|
||||
"Let's review one example to understand the internal structure:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "PDXQd5Jk2Bqe"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "hzSR_56wxKDA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = train_dataset.remove_columns(['problem', 'original_question', 'original_answer'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "T9rCkeqDODba"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YY3uMp909Eqy"
|
||||
},
|
||||
"source": [
|
||||
"## Load model and configure LoRA/QLoRA\n",
|
||||
"\n",
|
||||
"This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "gt05dgXgm9QR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Qwen3VLForConditionalGeneration, BitsAndBytesConfig\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"model = Qwen3VLForConditionalGeneration.from_pretrained(\n",
|
||||
" model_name, dtype=\"auto\",\n",
|
||||
" device_map=\"auto\",\n",
|
||||
" quantization_config=BitsAndBytesConfig(\n",
|
||||
" load_in_4bit=True,\n",
|
||||
" bnb_4bit_use_double_quant=True,\n",
|
||||
" bnb_4bit_quant_type=\"nf4\",\n",
|
||||
" bnb_4bit_compute_dtype=torch.float16\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
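{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally (not part of the original recipe), we can print the model's approximate memory footprint to confirm that the 4-bit weights fit comfortably on a T4. Note that `get_memory_footprint` only covers parameters and buffers, not the activations and optimizer state used later during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: approximate memory taken by the (quantized) model weights.\n",
"# This excludes activations and optimizer state used during training.\n",
"print(f\"Model memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GB\")"
]
},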
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "WZGf-GF09Gsc"
|
||||
},
|
||||
"source": [
|
||||
"The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ME1im5gh2LFg"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from peft import LoraConfig\n",
|
||||
"\n",
|
||||
"# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
|
||||
"# For example, different VLMs might have different attention/projection layer names.\n",
|
||||
"peft_config = LoraConfig(\n",
|
||||
" r=8,\n",
|
||||
" lora_alpha=32,\n",
|
||||
" lora_dropout=0.1,\n",
|
||||
" target_modules=[\"q_proj\", \"v_proj\"],\n",
|
||||
")"
|
||||
]
|
||||
},
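{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional check (not in the original notebook), we can list which modules of the loaded model match the LoRA `target_modules` names. This only inspects module names; the adapter itself is applied later by the trainer through `peft_config`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: confirm that the LoRA target module names exist in this architecture.\n",
"# The GRPOTrainer applies the adapter itself, so we only inspect names here.\n",
"matched = [name for name, _ in model.named_modules() if name.endswith((\"q_proj\", \"v_proj\"))]\n",
"print(f\"{len(matched)} modules match the LoRA target names, e.g. {matched[:3]}\")"
]
},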
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "mDq4V6dN9MGk"
|
||||
},
|
||||
"source": [
|
||||
"## Train model\n",
|
||||
"\n",
|
||||
"We'll configure **GRPO** using `GRPOConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL GRPOConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.GRPOConfig).\n",
|
||||
"\n",
|
||||
"First, we need to define the rewards functions that the training algorithm will use to improve the model. In this case, we'll include two reward functions.\n",
|
||||
"We'll use a format reward that will reward the model when the output includes `<think>` and `<answer>` tags and additionally a length-based reward to discourage overthinking. Both functions have been extracted from [here](https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Dqp3TfUwHUxW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import re\n",
|
||||
"\n",
|
||||
"def format_reward(completions, **kwargs):\n",
|
||||
" \"\"\"Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.\"\"\"\n",
|
||||
" pattern = r\"^<think>\\n.*?\\n</think>\\n<answer>\\n.*?\\n</answer>$\"\n",
|
||||
" matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]\n",
|
||||
" return [1.0 if match else 0.0 for match in matches]"
|
||||
]
|
||||
},
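{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (not part of the original recipe), we can run `format_reward` on two hand-written completions: one that follows the `<think>`/`<answer>` format and should score 1.0, and one that doesn't and should score 0.0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity check of format_reward on hand-written completions (illustrative only).\n",
"good = \"<think>\\nSome reasoning here.\\n</think>\\n<answer>\\n42\\n</answer>\"\n",
"bad = \"The answer is 42.\"\n",
"print(format_reward([good, bad]))  # expected: [1.0, 0.0]"
]
},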
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rxNPUp7RBFcz"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from math_verify import LatexExtractionConfig, parse, verify\n",
|
||||
"from latex2sympy2_extended import NormalizationConfig\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def len_reward(completions, solution, **kwargs) -> float:\n",
|
||||
" \"\"\"Compute length-based rewards to discourage overthinking and promote token efficiency.\n",
|
||||
"\n",
|
||||
" Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" completions: List of model completions\n",
|
||||
" solution: List of ground truth solutions\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" List of rewards where:\n",
|
||||
" - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)\n",
|
||||
" - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))\n",
|
||||
" \"\"\"\n",
|
||||
" contents = completions\n",
|
||||
"\n",
|
||||
" # First check correctness of answers\n",
|
||||
" correctness = []\n",
|
||||
" for content, sol in zip(contents, solution):\n",
|
||||
" gold_parsed = parse(\n",
|
||||
" sol,\n",
|
||||
" extraction_mode=\"first_match\",\n",
|
||||
" extraction_config=[LatexExtractionConfig()],\n",
|
||||
" )\n",
|
||||
" if len(gold_parsed) == 0:\n",
|
||||
" # Skip unparseable examples\n",
|
||||
" correctness.append(True) # Treat as correct to avoid penalizing\n",
|
||||
" print(\"Failed to parse gold solution: \", sol)\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" answer_parsed = parse(\n",
|
||||
" content,\n",
|
||||
" extraction_config=[\n",
|
||||
" LatexExtractionConfig(\n",
|
||||
" normalization_config=NormalizationConfig(\n",
|
||||
" nits=False,\n",
|
||||
" malformed_operators=False,\n",
|
||||
" basic_latex=True,\n",
|
||||
" equations=True,\n",
|
||||
" boxed=True,\n",
|
||||
" units=True,\n",
|
||||
" ),\n",
|
||||
" boxed_match_priority=0,\n",
|
||||
" try_extract_without_anchor=False,\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
" extraction_mode=\"first_match\",\n",
|
||||
" )\n",
|
||||
" correctness.append(verify(answer_parsed, gold_parsed))\n",
|
||||
"\n",
|
||||
" # Calculate lengths\n",
|
||||
" lengths = [len(content) for content in contents]\n",
|
||||
" min_len = min(lengths)\n",
|
||||
" max_len = max(lengths)\n",
|
||||
"\n",
|
||||
" # If all responses have the same length, return zero rewards\n",
|
||||
" if max_len == min_len:\n",
|
||||
" return [0.0] * len(completions)\n",
|
||||
"\n",
|
||||
" rewards = []\n",
|
||||
" for length, is_correct in zip(lengths, correctness):\n",
|
||||
" lambda_val = 0.5 - (length - min_len) / (max_len - min_len)\n",
|
||||
"\n",
|
||||
" if is_correct:\n",
|
||||
" reward = lambda_val\n",
|
||||
" else:\n",
|
||||
" reward = min(0, lambda_val)\n",
|
||||
"\n",
|
||||
" rewards.append(float(reward))\n",
|
||||
"\n",
|
||||
" return rewards\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "9xBL7Rni9LZb"
|
||||
},
|
||||
"source": [
|
||||
"After defining the reward function(s), we can define the `GRPOConfig`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "OEmRM0rIHXQ4"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from trl import GRPOConfig\n",
|
||||
"\n",
|
||||
"output_dir = \"Qwen3-VL-4B-Instruct-trl-grpo\"\n",
|
||||
"\n",
|
||||
"# Configure training arguments using GRPOConfig\n",
|
||||
"training_args = GRPOConfig(\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" #num_train_epochs=1,\n",
|
||||
" max_steps=100, # Number of dataset passes. For full trainings, use `num_train_epochs` instead\n",
|
||||
"\n",
|
||||
" # Parameters that control the data preprocessing\n",
|
||||
" per_device_train_batch_size=2,\n",
|
||||
" max_completion_length=1024, # default: 256 # Max completion length produced during training\n",
|
||||
" num_generations=2, # 2, # default: 8 # Number of generations produced during trainig for comparison\n",
|
||||
" max_prompt_length=2048, # default: 512 # Max prompt lenght of the input prompt used for generation during training\n",
|
||||
"\n",
|
||||
" fp16=True,\n",
|
||||
"\n",
|
||||
" # Parameters related to reporting and saving\n",
|
||||
" output_dir=output_dir, # Where to save model checkpoints and logs\n",
|
||||
" logging_steps=1, # Log training metrics every N steps\n",
|
||||
" report_to=\"trackio\", # Experiment tracking tool\n",
|
||||
"\n",
|
||||
" # Hub integration\n",
|
||||
" push_to_hub=True,\n",
|
||||
" log_completions=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "O0q3myQg927v"
|
||||
},
|
||||
"source": [
|
||||
"Configure the GRPO Trainer. We pass the previously configured `training_args`. We don't use eval dataset to maintain memory usage low but you can configure it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "z5JxkmS9HqD5",
|
||||
"outputId": "2b39338e-2194-4829-fc54-5e286566fd28"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.12/dist-packages/peft/mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.\n",
|
||||
" warnings.warn(\n",
|
||||
"/usr/local/lib/python3.12/dist-packages/peft/tuners/tuners_utils.py:196: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from trl import GRPOTrainer\n",
|
||||
"\n",
|
||||
"trainer = GRPOTrainer(\n",
|
||||
" model=model,\n",
|
||||
" reward_funcs=[format_reward, len_reward],\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=train_dataset,\n",
|
||||
" peft_config=peft_config,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kQC7Q5kg95xq"
|
||||
},
|
||||
"source": [
|
||||
"Show memory stats before training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "naG_7qlYyBP6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gpu_stats = torch.cuda.get_device_properties(0)\n",
|
||||
"start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
||||
"max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
|
||||
"\n",
|
||||
"print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
|
||||
"print(f\"{start_gpu_memory} GB of memory reserved.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YazYtLAe97Dc"
|
||||
},
|
||||
"source": [
|
||||
"And train!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "pbJXrhA0ywra"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer_stats = trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SmcYN5yW99IP"
|
||||
},
|
||||
"source": [
|
||||
"Show memory stats after training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "TrrwP4ADMmrp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
||||
"used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
|
||||
"used_percentage = round(used_memory / max_memory * 100, 3)\n",
|
||||
"lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
|
||||
"\n",
|
||||
"print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
|
||||
"print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
|
||||
"print(f\"Peak reserved memory = {used_memory} GB.\")\n",
|
||||
"print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
|
||||
"print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
|
||||
"print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "saarW87Y9_-R"
|
||||
},
|
||||
"source": [
|
||||
"## Saving fine tuned model\n",
|
||||
"\n",
|
||||
"In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "71A8aqEyyETA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.save_model(output_dir)\n",
|
||||
"trainer.push_to_hub(dataset_name=dataset_id)"
|
||||
]
|
||||
},
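{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally (not part of the original notebook), we can also store the processor alongside the adapter so the output directory is self-contained."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: save the processor next to the adapter weights (illustrative addition).\n",
"processor.save_pretrained(output_dir)"
]
},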
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "nfqvO0qw-OvS"
|
||||
},
|
||||
"source": [
|
||||
"## Load the fine-tuned model and run inference\n",
|
||||
"\n",
|
||||
"Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "R8T2uFQVyFeH"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Qwen3VLForConditionalGeneration, AutoProcessor\n",
|
||||
"from peft import PeftModel\n",
|
||||
"\n",
|
||||
"base_model = model_name\n",
|
||||
"adapter_model = f\"{output_dir}\" # Replace with your HF username or organization\n",
|
||||
"\n",
|
||||
"model = Qwen3VLForConditionalGeneration.from_pretrained(base_model, dtype=\"auto\", device_map=\"auto\")\n",
|
||||
"model = PeftModel.from_pretrained(model, adapter_model)\n",
|
||||
"\n",
|
||||
"processor = AutoProcessor.from_pretrained(base_model)"
|
||||
]
|
||||
},
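{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally (not part of the original notebook), the LoRA adapter can be merged into the base weights with PEFT's `merge_and_unload`, which removes the adapter indirection at inference time. This sketch assumes the base model was loaded without 4-bit quantization, as in the cell above; the output path in the commented line is just a placeholder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: merge the adapter into the base weights for standalone inference.\n",
"# Assumes `model` is a PeftModel wrapping a non-quantized base model (as above).\n",
"# merged_model = model.merge_and_unload()\n",
"# merged_model.save_pretrained(\"qwen3-vl-grpo-merged\")  # hypothetical local path"
]
},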
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dPBHP0CpLa6K"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "cG5-ccGRyHgo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
|
||||
"train_dataset = load_dataset(dataset_id, split='train[:5%]')\n",
|
||||
"\n",
|
||||
"problem = train_dataset[0]['problem']\n",
|
||||
"image = train_dataset[0]['image']\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\", \"content\": [\n",
|
||||
" {\"type\": \"text\", \"text\": SYSTEM_PROMPT}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": image},\n",
|
||||
" {\"type\": \"text\", \"text\": problem},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "r_70q_8lLgfV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "PX92MjqlyIwB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\"\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"# Inference: Generation of the output\n",
|
||||
"generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
|
||||
"generated_ids_trimmed = [\n",
|
||||
" out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
|
||||
"]\n",
|
||||
"output_text = processor.batch_decode(\n",
|
||||
" generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
||||
")\n",
|
||||
"print(output_text)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
515  examples/notebooks/sft_qwen_vl.ipynb  (new file)
File diff suppressed because one or more lines are too long