mirror of
https://github.com/huggingface/trl.git
synced 2025-11-12 01:04:41 +08:00
Signed-off-by: Yao, Matrix <matrix.yao@intel.com> Signed-off-by: YAO Matrix <matrix.yao@intel.com>
610 lines
22 KiB
Plaintext
610 lines
22 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "WQpNapZNWuXP"
|
|
},
|
|
"source": [
|
|
"\n",
|
|
"**Best-of-n sampling as an alternative to RLHF**\n",
|
|
"\n",
|
|
"This notebook compares reward-model scores of prompt-based responses from:\n",
|
|
"1. a base model (`gpt2-imdb`)\n",
|
|
"2. an `RLHF`-tuned model based on this base model (`gpt2-imdb-pos-v2`)\n",
|
|
"3. the base model again, from which we sample n responses to each prompt, score them, and keep the highest-scoring one (a.k.a. the `best-of-n sampled` model)\n",
|
|
"\n",
|
|
"Import dependencies"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "vDA6qayz692w"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%pip install transformers trl"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "M1s_iNm773hM"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"from transformers import pipeline, AutoTokenizer\n",
|
|
"from datasets import load_dataset\n",
|
|
"\n",
|
|
"from trl import AutoModelForCausalLMWithValueHead\n",
|
|
"from trl.core import LengthSampler\n",
|
|
"\n",
|
|
"# `torch.accelerator` only exists in newer PyTorch; use it when it reports a usable device,\n",
|
|
"# otherwise fall back to CUDA if present and finally to CPU.\n",
|
|
"if hasattr(torch, \"accelerator\") and torch.accelerator.is_available():\n",
|
|
"    device = torch.accelerator.current_accelerator().type\n",
|
|
"elif torch.cuda.is_available():\n",
|
|
"    device = \"cuda\"\n",
|
|
"else:\n",
|
|
"    device = \"cpu\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Y7hyrIrO8tcY"
|
|
},
|
|
"source": [
|
|
"Various constants"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"id": "MqS3OM6Q8x6g"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"ref_model_name = \"lvwerra/gpt2-imdb\"\n",
|
|
"model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n",
|
|
"reward_model = \"lvwerra/distilbert-imdb\"\n",
|
|
"\n",
|
|
"N_BEST_OF = 4"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "c1YcXeElg6or"
|
|
},
|
|
"source": [
|
|
"Models and tokenizers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "b855NrL181Hh"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n",
|
|
"\n",
|
|
"ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n",
|
|
"\n",
|
|
"reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n",
|
|
"\n",
|
|
"tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n",
|
|
"\n",
|
|
"tokenizer.pad_token = tokenizer.eos_token\n",
|
|
"\n",
|
|
"# put models to accelerator\n",
|
|
"model.to(device)\n",
|
|
"ref_model.to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Z1Cz0gCFhZYJ"
|
|
},
|
|
"source": [
|
|
"Dataset building"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"id": "LqLVEp5p_8XM"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Generating train split: 100%|██████████| 25000/25000 [00:00<00:00, 113700.67 examples/s]\n",
|
|
"Generating test split: 100%|██████████| 25000/25000 [00:00<00:00, 131049.39 examples/s]\n",
|
|
"Generating unsupervised split: 100%|██████████| 50000/50000 [00:00<00:00, 126486.39 examples/s]\n",
|
|
"Filter: 100%|██████████| 25000/25000 [00:00<00:00, 238843.61 examples/s]\n",
|
|
"Map: 0%| | 0/24895 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors\n",
|
|
"Map: 100%|██████████| 24895/24895 [00:17<00:00, 1462.36 examples/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def build_dataset(\n",
|
|
"    tokenizer,\n",
|
|
"    dataset_name=\"stanfordnlp/imdb\",\n",
|
|
"    input_min_text_length=2,\n",
|
|
"    input_max_text_length=8,\n",
|
|
"):\n",
|
|
"    \"\"\"Load the IMDB train split and attach a short tokenized prompt to each review.\n",
|
|
"\n",
|
|
"    Reviews of 200 characters or fewer are dropped. Each remaining review is cut to a\n",
|
|
"    prompt length drawn from LengthSampler(input_min_text_length, input_max_text_length);\n",
|
|
"    the kept token ids go to \"input_ids\" and their decoded text to \"query\".\n",
|
|
"    Returns the dataset with torch formatting.\n",
|
|
"    \"\"\"\n",
|
|
"    # load imdb with datasets\n",
|
|
"    ds = load_dataset(dataset_name, split=\"train\")\n",
|
|
"    ds = ds.rename_columns({\"text\": \"review\"})\n",
|
|
"    ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n",
|
|
"\n",
|
|
"    input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
|
|
"\n",
|
|
"    def tokenize(sample):\n",
|
|
"        # encode() sees the whole review (hence the >1024-token warning), but only the\n",
|
|
"        # first input_size() ids are kept, so the long sequence never reaches the model.\n",
|
|
"        sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n",
|
|
"        sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
|
|
"        return sample\n",
|
|
"\n",
|
|
"    ds = ds.map(tokenize, batched=False)\n",
|
|
"    ds.set_format(type=\"torch\")\n",
|
|
"    return ds\n",
|
|
"\n",
|
|
"\n",
|
|
"dataset = build_dataset(tokenizer)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"id": "AqA2McjMAxNw"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"gen_kwargs = {\n",
|
|
" \"min_length\": -1,\n",
|
|
" \"top_k\": 0.0,\n",
|
|
" \"top_p\": 1.0,\n",
|
|
" \"do_sample\": True,\n",
|
|
" \"pad_token_id\": tokenizer.eos_token_id,\n",
|
|
"}\n",
|
|
"sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"id": "L_q4qs35AxcR"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"output_min_length = 4\n",
|
|
"output_max_length = 16\n",
|
|
"output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
|
|
"\n",
|
|
"#### get a batch from the dataset\n",
|
|
"bs = 16\n",
|
|
"output_data = dict()\n",
|
|
"dataset.set_format(\"pandas\")\n",
|
|
"df_batch = dataset[:].sample(bs)\n",
|
|
"output_data[\"query\"] = df_batch[\"query\"].tolist()\n",
|
|
"query_tensors = df_batch[\"input_ids\"].tolist()\n",
|
|
"\n",
|
|
"# :: [Resp]\n",
|
|
"response_tensors_ref, response_tensors = [], []\n",
|
|
"# :: [[Resp]]\n",
|
|
"response_tensors_best_of = []"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "QVfpyHnZBLKY"
|
|
},
|
|
"source": [
|
|
"\n",
|
|
"Generation using various models"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "-imZ7uEFBNbw"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def generate_texts(gen_model, input_ids, max_new_tokens):\n",
|
|
"    \"\"\"Sample continuations for a batch of prompts and decode them (prompt text included).\"\"\"\n",
|
|
"    output = gen_model.generate(input_ids, max_new_tokens=max_new_tokens, **gen_kwargs)\n",
|
|
"    # drop EOS/pad tokens so the reward model scores clean text\n",
|
|
"    return tokenizer.batch_decode(output, skip_special_tokens=True)\n",
|
|
"\n",
|
|
"\n",
|
|
"for i in range(bs):\n",
|
|
"    gen_len = output_length_sampler()\n",
|
|
"    query = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)\n",
|
|
"\n",
|
|
"    response_tensors_ref.append(generate_texts(ref_model, query, gen_len)[0])\n",
|
|
"    response_tensors.append(generate_texts(model, query, gen_len)[0])\n",
|
|
"\n",
|
|
"    # Best-of-n: N_BEST_OF sampled continuations of the same prompt from the base model.\n",
|
|
"    # No .squeeze() on the batch output, so this stays a list of N_BEST_OF texts even when N_BEST_OF == 1.\n",
|
|
"    queries = query.repeat((N_BEST_OF, 1))\n",
|
|
"    response_tensors_best_of.append(generate_texts(ref_model, queries, gen_len))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Jp5FC0Y5h_Sf"
|
|
},
|
|
"source": [
|
|
"Scoring"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"id": "PyDbbAQ0F_h7"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def positive_scores(texts):\n",
|
|
"    \"\"\"Return the reward model's POSITIVE-label score for each text.\n",
|
|
"\n",
|
|
"    With top_k=None the pipeline returns every label per input sorted by score, so\n",
|
|
"    output[0] is whichever label won -- not necessarily POSITIVE.\n",
|
|
"    \"\"\"\n",
|
|
"    outputs = reward_pipe(texts, **sent_kwargs)\n",
|
|
"    return [{item[\"label\"]: item[\"score\"] for item in output}[\"POSITIVE\"] for output in outputs]\n",
|
|
"\n",
|
|
"\n",
|
|
"scores_ref = positive_scores(response_tensors_ref)\n",
|
|
"scores = positive_scores(response_tensors)\n",
|
|
"# one tensor of N_BEST_OF scores per prompt; the results cell keeps the argmax candidate\n",
|
|
"scores_best_of = [torch.tensor(positive_scores(responses)) for responses in response_tensors_best_of]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 682
|
|
},
|
|
"id": "nA1GDNJEiGm-",
|
|
"outputId": "1389c686-0751-4304-dea2-b71fd68748e1"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>query</th>\n",
|
|
" <th>response (ref)</th>\n",
|
|
" <th>scores (ref)</th>\n",
|
|
" <th>response (RLHF)</th>\n",
|
|
" <th>scores (RLHF)</th>\n",
|
|
" <th>response (best_of)</th>\n",
|
|
" <th>scores (best_of)</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>This movie is one of</td>\n",
|
|
" <td>This movie is one of the most twisted films I</td>\n",
|
|
" <td>2.094254</td>\n",
|
|
" <td>This movie is one of the finest directors of the</td>\n",
|
|
" <td>2.726879</td>\n",
|
|
" <td>This movie is one of the best looking movies I</td>\n",
|
|
" <td>2.705925</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>one may</td>\n",
|
|
" <td>one may feel we are seeing more</td>\n",
|
|
" <td>1.478813</td>\n",
|
|
" <td>one may not have great assets,</td>\n",
|
|
" <td>0.420451</td>\n",
|
|
" <td>one may not be supported, terrible</td>\n",
|
|
" <td>2.043730</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>This is an amazing film,</td>\n",
|
|
" <td>This is an amazing film, one of our favorite g...</td>\n",
|
|
" <td>2.871389</td>\n",
|
|
" <td>This is an amazing film, with all thelike wond...</td>\n",
|
|
" <td>2.918770</td>\n",
|
|
" <td>This is an amazing film, very moving and this ...</td>\n",
|
|
" <td>2.871694</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>just below</td>\n",
|
|
" <td>just below)and makes it seem as</td>\n",
|
|
" <td>0.861618</td>\n",
|
|
" <td>just below the world capital is a man</td>\n",
|
|
" <td>0.238322</td>\n",
|
|
" <td>just below) in this beautiful comedy.</td>\n",
|
|
" <td>2.760033</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>Return To the</td>\n",
|
|
" <td>Return To the Museum. That film, called Bl</td>\n",
|
|
" <td>0.017376</td>\n",
|
|
" <td>Return To the East\" is a fascinating film,</td>\n",
|
|
" <td>2.648028</td>\n",
|
|
" <td>Return To the International: Miyazaki, by Ts</td>\n",
|
|
" <td>1.072344</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>5</th>\n",
|
|
" <td>Brando plays the ace jet</td>\n",
|
|
" <td>Brando plays the ace jet fighter pilot, who stops</td>\n",
|
|
" <td>0.565335</td>\n",
|
|
" <td>Brando plays the ace jet pilot, who's a</td>\n",
|
|
" <td>0.668954</td>\n",
|
|
" <td>Brando plays the ace jet pilot Charlie; his fo...</td>\n",
|
|
" <td>0.679582</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>6</th>\n",
|
|
" <td>And a rather U</td>\n",
|
|
" <td>And a rather Utopian horror movie and with good</td>\n",
|
|
" <td>2.245751</td>\n",
|
|
" <td>And a rather Utop Congressional Movie, with a 45</td>\n",
|
|
" <td>0.307100</td>\n",
|
|
" <td>And a rather U of A complete combination of wh...</td>\n",
|
|
" <td>2.209265</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>7</th>\n",
|
|
" <td>The plot of this movie hangs</td>\n",
|
|
" <td>The plot of this movie hangs in the balance as...</td>\n",
|
|
" <td>1.122540</td>\n",
|
|
" <td>The plot of this movie hangs out well. The who...</td>\n",
|
|
" <td>2.195263</td>\n",
|
|
" <td>The plot of this movie hangs together within t...</td>\n",
|
|
" <td>1.310783</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>8</th>\n",
|
|
" <td>This isn't</td>\n",
|
|
" <td>This isn't all that bad; as for my</td>\n",
|
|
" <td>0.623968</td>\n",
|
|
" <td>This isn't a good film because I loved it</td>\n",
|
|
" <td>1.694601</td>\n",
|
|
" <td>This isn't bad writing, powerful actors and sp...</td>\n",
|
|
" <td>1.835901</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>9</th>\n",
|
|
" <td>This movie was for a</td>\n",
|
|
" <td>This movie was for a good reason!' Uh, OK</td>\n",
|
|
" <td>0.437566</td>\n",
|
|
" <td>This movie was for a fun, and grand Robinson</td>\n",
|
|
" <td>2.531890</td>\n",
|
|
" <td>This movie was for a bastard.<br /><br</td>\n",
|
|
" <td>2.311337</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>10</th>\n",
|
|
" <td>witty. funny.</td>\n",
|
|
" <td>witty. funny.<|endoftext|></td>\n",
|
|
" <td>1.636344</td>\n",
|
|
" <td>witty. funny. funnier. more funny. funnier. fu...</td>\n",
|
|
" <td>2.132353</td>\n",
|
|
" <td>witty. funny. In the first scene the comical n...</td>\n",
|
|
" <td>2.164077</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>11</th>\n",
|
|
" <td>It's very hard</td>\n",
|
|
" <td>It's very hard to believe that anyone would en...</td>\n",
|
|
" <td>1.003727</td>\n",
|
|
" <td>It's very hard to wrap your mind around what h...</td>\n",
|
|
" <td>0.778888</td>\n",
|
|
" <td>It's very hard to wrap this up, due to lack of...</td>\n",
|
|
" <td>1.598843</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>12</th>\n",
|
|
" <td>Absolutely fantastic trash....this one</td>\n",
|
|
" <td>Absolutely fantastic trash....this one was hav...</td>\n",
|
|
" <td>1.350834</td>\n",
|
|
" <td>Absolutely fantastic trash....this one is a pe...</td>\n",
|
|
" <td>2.177587</td>\n",
|
|
" <td>Absolutely fantastic trash....this one ruins i...</td>\n",
|
|
" <td>2.221997</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>13</th>\n",
|
|
" <td>Prior to</td>\n",
|
|
" <td>Prior to this action film,</td>\n",
|
|
" <td>0.242474</td>\n",
|
|
" <td>Prior to Christian Kane's star</td>\n",
|
|
" <td>0.297408</td>\n",
|
|
" <td>Prior to his restoration, Passion</td>\n",
|
|
" <td>1.655534</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>14</th>\n",
|
|
" <td>i,</td>\n",
|
|
" <td>i, Marty Rathbun, Damon Wayans, Mark Watney and</td>\n",
|
|
" <td>0.105734</td>\n",
|
|
" <td>i, perhaps the great movie the director should...</td>\n",
|
|
" <td>1.336116</td>\n",
|
|
" <td>i, Martin was a thrill of 70s---wow!lee and Heath</td>\n",
|
|
" <td>2.277638</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>15</th>\n",
|
|
" <td>The film</td>\n",
|
|
" <td>The film takes a very grim craggy look</td>\n",
|
|
" <td>0.069017</td>\n",
|
|
" <td>The film is one of the best of that era</td>\n",
|
|
" <td>2.737825</td>\n",
|
|
" <td>The film's ambition was almost so great that its</td>\n",
|
|
" <td>2.357480</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" query \\\n",
|
|
"0 This movie is one of \n",
|
|
"1 one may \n",
|
|
"2 This is an amazing film, \n",
|
|
"3 just below \n",
|
|
"4 Return To the \n",
|
|
"5 Brando plays the ace jet \n",
|
|
"6 And a rather U \n",
|
|
"7 The plot of this movie hangs \n",
|
|
"8 This isn't \n",
|
|
"9 This movie was for a \n",
|
|
"10 witty. funny. \n",
|
|
"11 It's very hard \n",
|
|
"12 Absolutely fantastic trash....this one \n",
|
|
"13 Prior to \n",
|
|
"14 i, \n",
|
|
"15 The film \n",
|
|
"\n",
|
|
" response (ref) scores (ref) \\\n",
|
|
"0 This movie is one of the most twisted films I 2.094254 \n",
|
|
"1 one may feel we are seeing more 1.478813 \n",
|
|
"2 This is an amazing film, one of our favorite g... 2.871389 \n",
|
|
"3 just below)and makes it seem as 0.861618 \n",
|
|
"4 Return To the Museum. That film, called Bl 0.017376 \n",
|
|
"5 Brando plays the ace jet fighter pilot, who stops 0.565335 \n",
|
|
"6 And a rather Utopian horror movie and with good 2.245751 \n",
|
|
"7 The plot of this movie hangs in the balance as... 1.122540 \n",
|
|
"8 This isn't all that bad; as for my 0.623968 \n",
|
|
"9 This movie was for a good reason!' Uh, OK 0.437566 \n",
|
|
"10 witty. funny.<|endoftext|> 1.636344 \n",
|
|
"11 It's very hard to believe that anyone would en... 1.003727 \n",
|
|
"12 Absolutely fantastic trash....this one was hav... 1.350834 \n",
|
|
"13 Prior to this action film, 0.242474 \n",
|
|
"14 i, Marty Rathbun, Damon Wayans, Mark Watney and 0.105734 \n",
|
|
"15 The film takes a very grim craggy look 0.069017 \n",
|
|
"\n",
|
|
" response (RLHF) scores (RLHF) \\\n",
|
|
"0 This movie is one of the finest directors of the 2.726879 \n",
|
|
"1 one may not have great assets, 0.420451 \n",
|
|
"2 This is an amazing film, with all thelike wond... 2.918770 \n",
|
|
"3 just below the world capital is a man 0.238322 \n",
|
|
"4 Return To the East\" is a fascinating film, 2.648028 \n",
|
|
"5 Brando plays the ace jet pilot, who's a 0.668954 \n",
|
|
"6 And a rather Utop Congressional Movie, with a 45 0.307100 \n",
|
|
"7 The plot of this movie hangs out well. The who... 2.195263 \n",
|
|
"8 This isn't a good film because I loved it 1.694601 \n",
|
|
"9 This movie was for a fun, and grand Robinson 2.531890 \n",
|
|
"10 witty. funny. funnier. more funny. funnier. fu... 2.132353 \n",
|
|
"11 It's very hard to wrap your mind around what h... 0.778888 \n",
|
|
"12 Absolutely fantastic trash....this one is a pe... 2.177587 \n",
|
|
"13 Prior to Christian Kane's star 0.297408 \n",
|
|
"14 i, perhaps the great movie the director should... 1.336116 \n",
|
|
"15 The film is one of the best of that era 2.737825 \n",
|
|
"\n",
|
|
" response (best_of) scores (best_of) \n",
|
|
"0 This movie is one of the best looking movies I 2.705925 \n",
|
|
"1 one may not be supported, terrible 2.043730 \n",
|
|
"2 This is an amazing film, very moving and this ... 2.871694 \n",
|
|
"3 just below) in this beautiful comedy. 2.760033 \n",
|
|
"4 Return To the International: Miyazaki, by Ts 1.072344 \n",
|
|
"5 Brando plays the ace jet pilot Charlie; his fo... 0.679582 \n",
|
|
"6 And a rather U of A complete combination of wh... 2.209265 \n",
|
|
"7 The plot of this movie hangs together within t... 1.310783 \n",
|
|
"8 This isn't bad writing, powerful actors and sp... 1.835901 \n",
|
|
"9 This movie was for a bastard.<br /><br 2.311337 \n",
|
|
"10 witty. funny. In the first scene the comical n... 2.164077 \n",
|
|
"11 It's very hard to wrap this up, due to lack of... 1.598843 \n",
|
|
"12 Absolutely fantastic trash....this one ruins i... 2.221997 \n",
|
|
"13 Prior to his restoration, Passion 1.655534 \n",
|
|
"14 i, Martin was a thrill of 70s---wow!lee and Heath 2.277638 \n",
|
|
"15 The film's ambition was almost so great that its 2.357480 "
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output_data[\"response (ref)\"] = response_tensors_ref\n",
|
|
"output_data[\"scores (ref)\"] = scores_ref\n",
|
|
"output_data[\"response (RLHF)\"] = response_tensors\n",
|
|
"output_data[\"scores (RLHF)\"] = scores\n",
|
|
"output_data[\"response (best_of)\"] = [\n",
|
|
" response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)\n",
|
|
"]\n",
|
|
"output_data[\"scores (best_of)\"] = [a.max().item() for a in scores_best_of]\n",
|
|
"\n",
|
|
"\n",
|
|
"# store results in a dataframe\n",
|
|
"df_results = pd.DataFrame(output_data)\n",
|
|
"df_results"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"gpuClass": "standard",
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|