Dynamic input sizes (#35)
* change ppo input from tensor to list of tensors for varying shapes
* update readme example with new input type
* update docs
* add listification of tensors needed for new API
* replace nans in tensors for wandb compatibility
* add `listify_batch` helper function for backwards compatibility
* update sentiment example with new api
* update docs
* update library
* ignore wandb artifacts
* update requirements
* run experiment
* replace respond_to_batch with generate
* add experiment
* update docs
* fix action
* fix action
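The central API change: `PPOTrainer.step` now takes lists of per-example tensors, so queries and responses may have different lengths, instead of stacked batch tensors. A minimal sketch of the old and new calling conventions; the token values below are illustrative only, not taken from the diff:

```python
import torch

# before: fixed-shape batch tensors
# query_tensors:    [batch_size, query_len]
# response_tensors: [batch_size, response_len]
# rewards:          [batch_size]
# train_stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

# after: lists of 1-D tensors, one entry per example, lengths may differ
queries   = [torch.tensor([464, 3290]), torch.tensor([1212, 318, 257])]
responses = [torch.tensor([318, 922]),  torch.tensor([1049, 1110, 0, 13])]
rewards   = [torch.tensor(1.0), torch.tensor(-0.5)]
# train_stats = ppo_trainer.step(queries, responses, rewards)
```

Keeping rewards as a list of scalar tensors mirrors the per-example structure of the queries and responses.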
.github/workflows/main.yml
@ -30,4 +30,4 @@ jobs:
if [ -n "$(nbdev_diff_nbs)" ]; then echo -e "!!! Detected difference between the notebooks and the library"; false; fi
- name: Run tests
run: |
nbdev_test_nbs
nbdev_test_nbs --fname 'nbs/[!03|!04|!05|]*.ipynb'
.gitignore
@ -139,3 +139,5 @@ checklink/cookies.txt
# .gitconfig is now autogenerated
.gitconfig

nbs/wandb/
README.md
@ -3,11 +3,11 @@

## What is it?
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the `transformer` library by 🤗 Hugging Face ([link](https://github.com/huggingface/transformers)). Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GPT2 is implemented.
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the [`transformer`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point only decoder architectures such as GPT2 are implemented.

**Highlights:**
- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.

## How it works
@ -29,27 +29,29 @@ This process is illustrated in the sketch below:

### Python package
Install the library with pip:
```bash
pip install trl
```

`pip install trl`

### Repository
### From source
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:

`git clone https://github.com/lvwerra/trl.git`

`cd trl/`

`pip install -r requirements.txt`

```bash
git clone https://github.com/lvwerra/trl.git
cd trl/
pip install -r requirements.txt
```

### Jupyter notebooks

If you run Jupyter notebooks you might need to run the following:

`jupyter nbextension enable --py --sys-prefix widgetsnbextension`
```bash
jupyter nbextension enable --py --sys-prefix widgetsnbextension
```

For Jupyterlab additionally this command:

`jupyter labextension install @jupyter-widgets/jupyterlab-manager`
```bash
jupyter labextension install @jupyter-widgets/jupyterlab-manager
```

## How to use

@ -70,7 +72,7 @@ gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# initialize trainer
ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, **ppo_config)
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)

# encode a query
query_txt = "This morning I went to the "
@ -82,14 +84,14 @@ response_txt = gpt2_tokenizer.decode(response_tensor[0,:])

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = torch.tensor([1.0])
reward = [torch.tensor(1.0)]

# train model with ppo
train_stats = ppo_trainer.step(query_tensor, response_tensor, reward)
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
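The snippet above uses a constant reward. In practice the reward usually comes from another model, for example a sentiment classifier as in the IMDB example below. A hedged sketch of what that could look like with the new list-based API; the pipeline model and the reward scaling are illustrative assumptions, not taken from the repository:

```python
import torch
from transformers import pipeline

# any scalar signal works as a reward; here we (hypothetically) use the
# positive-class score of an off-the-shelf sentiment classifier
sentiment_pipe = pipeline("sentiment-analysis")

def sentiment_reward(texts):
    """Return one scalar reward tensor per generated text."""
    rewards = []
    for out in sentiment_pipe(texts):
        score = out["score"] if out["label"] == "POSITIVE" else -out["score"]
        rewards.append(torch.tensor(score))
    return rewards

# rewards = sentiment_reward([response_txt])
# train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], rewards)
```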
### Advanced example: IMDB sentiment
For a detailed example check out the notebook *Tune GPT2 to generate positive reviews*, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:
For a detailed example check out the notebook `04-gpt2-sentiment-ppo-training.ipynb`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:

<div style="text-align: center">
<img src="nbs/images/table_imdb_preview.png" width="800">
@ -104,8 +106,10 @@ This library is built with `nbdev` and as such all the library code as well as e
- `00-core.ipynb`: Contains the utility functions used throughout the library and examples.
- `01-gpt2-with-value-head.ipynb`: Implementation of a `transformer` compatible GPT2 model with an additional value head as well as a function to generate sequences.
- `02-ppo.ipynb`: Implementation of the PPOTrainer used to train language models.
- `03-bert-imdb-training.ipynb`: Training of BERT with `simpletransformers` to classify sentiment on the IMDB dataset.
- `03-bert-imdb-training.ipynb`: Training of DistilBERT to classify sentiment on the IMDB dataset.
- `04-gpt2-sentiment-ppo-training.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce positive movie reviews.

Currently using `trl==0.0.3`:
- `05-gpt2-sentiment-control.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce movie reviews with controlled sentiment.

## References
@ -114,4 +118,4 @@ This library is built with `nbdev` and as such all the library code as well as e
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].

### Language models
The language models utilize the `transformer` library by 🤗Hugging Face.
The language models utilize the `transformers` library by 🤗 Hugging Face.
@ -31,6 +31,19 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
<div class="cell border-box-sizing code_cell rendered">
|
||||
|
||||
</div>
|
||||
{% endraw %}
|
||||
|
||||
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
|
||||
<div class="text_cell_render border-box-sizing rendered_html">
|
||||
<h2 id="Constants">Constants<a class="anchor-link" href="#Constants"> </a></h2>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% raw %}
|
||||
|
||||
<div class="cell border-box-sizing code_cell rendered">
|
||||
|
||||
</div>
|
||||
{% endraw %}
|
||||
|
||||
@ -66,7 +79,7 @@ description: "A set of utility functions used throughout the library."
|
||||
<span class="n">results</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">stats_dicts</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
|
||||
<span class="n">stats_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">d</span><span class="p">[</span><span class="n">k</span><span class="p">])</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">stats_dicts</span><span class="p">]</span>
|
||||
<span class="n">results</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">stats_list</span><span class="p">)</span>
|
||||
<span class="n">results</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">pad_sequence</span><span class="p">(</span><span class="n">stats_list</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="n">WANDB_PADDING</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">results</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">add_suffix</span><span class="p">(</span><span class="n">input_dict</span><span class="p">,</span> <span class="n">suffix</span><span class="p">):</span>
|
||||
@ -92,7 +105,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L14" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
|
||||
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L20" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
|
||||
</blockquote>
|
||||
<p>Flatten dictionary and concatenate nested keys with separator.</p>
|
||||
|
||||
@ -117,7 +130,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L28" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
|
||||
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L34" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Stack the values of a dict.</p>
|
||||
|
||||
@ -142,7 +155,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L36" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
|
||||
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Add suffix to dict keys.</p>
|
||||
|
||||
@ -227,6 +240,10 @@ description: "A set of utility functions used throughout the library."
|
||||
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]):</span>
|
||||
<span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
|
||||
<span class="k">return</span> <span class="n">new_dict</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">listify_batch</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
|
||||
<span class="sd">"""Turns the first dimension of a tensor into a list."""</span>
|
||||
<span class="k">return</span> <span class="p">[</span><span class="n">tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
@ -247,7 +264,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
|
||||
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L48" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
|
||||
</blockquote>
|
||||
<p>Pad tensor to size.</p>
|
||||
|
||||
@ -272,7 +289,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L50" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
|
||||
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L56" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>See: <a href="https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591">https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591</a></p>
|
||||
|
||||
@ -297,7 +314,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L59" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
|
||||
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L65" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
|
||||
</blockquote>
|
||||
<p>Whiten values.</p>
|
||||
|
||||
@ -322,7 +339,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L67" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
|
||||
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L73" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Tensor extenstion to torch.clamp
|
||||
<a href="https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713">https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713</a></p>
|
||||
@ -348,7 +365,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L75" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
|
||||
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L81" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Calculate entropy from logits.</p>
|
||||
|
||||
@ -373,7 +390,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L82" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
|
||||
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L88" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Average values of a list of dicts wiht torch tensors.</p>
|
||||
|
||||
@ -398,7 +415,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L89" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
|
||||
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L95" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Cast all torch.tensors in dict to numpy arrays.</p>
|
||||
|
||||
@ -409,6 +426,31 @@ description: "A set of utility functions used throughout the library."
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
{% endraw %}
|
||||
|
||||
{% raw %}
|
||||
|
||||
<div class="cell border-box-sizing code_cell rendered">
|
||||
|
||||
<div class="output_wrapper">
|
||||
<div class="output">
|
||||
|
||||
<div class="output_area">
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="listify_batch" class="doc_header"><code>listify_batch</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L107" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>listify_batch</code>(<strong><code>tensor</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Turns the first dimension of a tensor into a list.</p>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
{% endraw %}
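The new `listify_batch` helper documented above simply splits a batched tensor along its first dimension, so existing tensor-based code can feed the list-based `step` API. A small illustrative usage sketch (not part of the generated docs; assumes `trl` is installed):

```python
import torch
from trl.core import listify_batch

batch = torch.arange(6).reshape(2, 3)   # shape [batch_size, seq_len]
rows = listify_batch(batch)             # -> [tensor([0, 1, 2]), tensor([3, 4, 5])]
assert len(rows) == batch.shape[0]
```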
|
||||
|
||||
@ -468,7 +510,7 @@ description: "A set of utility functions used throughout the library."
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L104" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
|
||||
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L113" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Create token id and attention mask tensors from text list for BERT classification.</p>
|
||||
|
||||
|
@ -53,6 +53,56 @@ description: "A GPT2 model with a value head built on the `transformer` library
|
||||
<div class="cell border-box-sizing code_cell rendered">
|
||||
<div class="input">
|
||||
|
||||
<div class="inner_cell">
|
||||
<div class="input_area">
|
||||
<div class=" highlight hl-ipython3"><pre><span></span><span class="nd">@dataclass</span>
|
||||
<span class="k">class</span> <span class="nc">CausalLMOutputWithCrossAttentions</span><span class="p">(</span><span class="n">ModelOutput</span><span class="p">):</span>
|
||||
<span class="n">loss</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="n">past_key_values</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="n">attentions</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="n">cross_attentions</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="n">value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
{% endraw %}
|
||||
|
||||
{% raw %}
|
||||
|
||||
<div class="cell border-box-sizing code_cell rendered">
|
||||
|
||||
<div class="output_wrapper">
|
||||
<div class="output">
|
||||
|
||||
<div class="output_area">
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h2 id="CausalLMOutputWithCrossAttentions" class="doc_header"><code>class</code> <code>CausalLMOutputWithCrossAttentions</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L18" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>CausalLMOutputWithCrossAttentions</code>(<strong><code>loss</code></strong>:<code>Optional</code>[<code>FloatTensor</code>]=<em><code>None</code></em>, <strong><code>logits</code></strong>:<code>FloatTensor</code>=<em><code>None</code></em>, <strong><code>past_key_values</code></strong>:<code>Optional</code>[<code>Tuple</code>[<code>Tuple</code>[<code>FloatTensor</code>]]]=<em><code>None</code></em>, <strong><code>hidden_states</code></strong>:<code>Optional</code>[<code>Tuple</code>[<code>FloatTensor</code>]]=<em><code>None</code></em>, <strong><code>attentions</code></strong>:<code>Optional</code>[<code>Tuple</code>[<code>FloatTensor</code>]]=<em><code>None</code></em>, <strong><code>cross_attentions</code></strong>:<code>Optional</code>[<code>Tuple</code>[<code>FloatTensor</code>]]=<em><code>None</code></em>, <strong><code>value</code></strong>:<code>Optional</code>[<code>FloatTensor</code>]=<em><code>None</code></em>) :: <code>ModelOutput</code></p>
|
||||
</blockquote>
|
||||
<p>CausalLMOutputWithCrossAttentions(loss: Optional[torch.FloatTensor] = None, logits: torch.FloatTensor = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, hidden_states: Optional[Tuple[torch.FloatTensor]] = None, attentions: Optional[Tuple[torch.FloatTensor]] = None, cross_attentions: Optional[Tuple[torch.FloatTensor]] = None, value: Optional[torch.FloatTensor] = None)</p>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
{% endraw %}
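The value-head model's forward pass can now return this dataclass instead of a plain tuple. A hedged sketch of how the two return modes might be consumed; the model loading and unpacking follow patterns used elsewhere in this diff and should be treated as illustrative, not as the documented interface:

```python
import torch
from trl.gpt2 import GPT2HeadWithValueModel

model = GPT2HeadWithValueModel.from_pretrained('gpt2')
input_ids = torch.tensor([[50256, 1212, 318, 257, 1332]])

# default: tuple output (logits first, value last), as consumed by the PPOTrainer
logits, *_, value = model(input_ids)          # value has shape [batch, seq_len]

# new: dataclass output with named fields
out = model(input_ids, return_dict=True)
print(out.logits.shape, out.value.shape)
```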
|
||||
|
||||
{% raw %}
|
||||
|
||||
<div class="cell border-box-sizing code_cell rendered">
|
||||
<div class="input">
|
||||
|
||||
<div class="inner_cell">
|
||||
<div class="input_area">
|
||||
<div class=" highlight hl-ipython3"><pre><span></span><span class="k">class</span> <span class="nc">ValueHead</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
||||
@ -117,7 +167,7 @@ description: "A GPT2 model with a value head built on the `transformer` library
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h2 id="ValueHead" class="doc_header"><code>class</code> <code>ValueHead</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L16" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ValueHead</code>(<strong><code>config</code></strong>) :: <code>Module</code></p>
|
||||
<h2 id="ValueHead" class="doc_header"><code>class</code> <code>ValueHead</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L30" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ValueHead</code>(<strong><code>config</code></strong>) :: <code>Module</code></p>
|
||||
</blockquote>
|
||||
<p>The ValueHead class implements a head for GPT2 that returns a scalar for each output token.</p>
|
||||
|
||||
@ -167,8 +217,11 @@ description: "A GPT2 model with a value head built on the `transformer` library
|
||||
<span class="n">mc_token_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">lm_labels</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">mc_labels</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">return_dict</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">output_attentions</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">output_hidden_states</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
|
||||
<span class="n">loss</span><span class="o">=</span><span class="kc">None</span>
|
||||
<span class="n">transformer_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="p">(</span>
|
||||
<span class="n">input_ids</span><span class="p">,</span>
|
||||
<span class="n">past_key_values</span><span class="o">=</span><span class="n">past_key_values</span><span class="p">,</span>
|
||||
@ -184,8 +237,20 @@ description: "A GPT2 model with a value head built on the `transformer` library
|
||||
<span class="n">lm_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
|
||||
<span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">v_head</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
|
||||
<span class="n">outputs</span> <span class="o">=</span> <span class="p">(</span><span class="n">lm_logits</span><span class="p">,)</span> <span class="o">+</span> <span class="n">transformer_outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span><span class="p">,)</span>
|
||||
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">return_dict</span><span class="p">:</span>
|
||||
<span class="n">outputs</span> <span class="o">=</span> <span class="p">(</span><span class="n">lm_logits</span><span class="p">,)</span> <span class="o">+</span> <span class="n">transformer_outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span><span class="p">,)</span>
|
||||
<span class="k">return</span> <span class="n">outputs</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">CausalLMOutputWithCrossAttentions</span><span class="p">(</span>
|
||||
<span class="n">loss</span><span class="o">=</span><span class="n">loss</span><span class="p">,</span>
|
||||
<span class="n">logits</span><span class="o">=</span><span class="n">lm_logits</span><span class="p">,</span>
|
||||
<span class="n">past_key_values</span><span class="o">=</span><span class="n">transformer_outputs</span><span class="o">.</span><span class="n">past_key_values</span><span class="p">,</span>
|
||||
<span class="n">hidden_states</span><span class="o">=</span><span class="n">transformer_outputs</span><span class="o">.</span><span class="n">hidden_states</span><span class="p">,</span>
|
||||
<span class="n">attentions</span><span class="o">=</span><span class="n">transformer_outputs</span><span class="o">.</span><span class="n">attentions</span><span class="p">,</span>
|
||||
<span class="n">cross_attentions</span><span class="o">=</span><span class="n">transformer_outputs</span><span class="o">.</span><span class="n">cross_attentions</span><span class="p">,</span>
|
||||
<span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">outputs</span>
|
||||
</pre></div>
|
||||
|
||||
@ -207,7 +272,7 @@ description: "A GPT2 model with a value head built on the `transformer` library
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h2 id="GPT2HeadWithValueModel" class="doc_header"><code>class</code> <code>GPT2HeadWithValueModel</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L61" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>GPT2HeadWithValueModel</code>(<strong><code>config</code></strong>) :: <code>GPT2PreTrainedModel</code></p>
|
||||
<h2 id="GPT2HeadWithValueModel" class="doc_header"><code>class</code> <code>GPT2HeadWithValueModel</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L75" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>GPT2HeadWithValueModel</code>(<strong><code>config</code></strong>) :: <code>GPT2PreTrainedModel</code></p>
|
||||
</blockquote>
|
||||
<p>The GPT2HeadWithValueModel class implements a GPT2 language model with a secondary, scalar head.</p>
|
||||
|
||||
@ -243,6 +308,21 @@ description: "A GPT2 model with a value head built on the `transformer` library
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="output_wrapper">
|
||||
<div class="output">
|
||||
|
||||
<div class="output_area">
|
||||
|
||||
<div class="output_subarea output_stream output_stderr output_text">
|
||||
<pre>Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.6.attn.masked_bias', 'h.10.attn.masked_bias', 'h.0.attn.masked_bias', 'h.3.attn.masked_bias', 'h.7.attn.masked_bias', 'h.5.attn.masked_bias', 'h.11.attn.masked_bias', 'h.9.attn.masked_bias', 'h.8.attn.masked_bias', 'lm_head.weight', 'h.4.attn.masked_bias', 'v_head.summary.weight', 'h.2.attn.masked_bias', 'v_head.summary.bias', 'h.1.attn.masked_bias']
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
{% endraw %}
|
||||
|
||||
@ -508,7 +588,7 @@ description: "A GPT2 model with a value head built on the `transformer` library
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h4 id="respond_to_batch" class="doc_header"><code>respond_to_batch</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L113" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>respond_to_batch</code>(<strong><code>model</code></strong>, <strong><code>queries</code></strong>, <strong><code>txt_len</code></strong>=<em><code>20</code></em>, <strong><code>top_k</code></strong>=<em><code>0</code></em>, <strong><code>top_p</code></strong>=<em><code>1.0</code></em>)</p>
|
||||
<h4 id="respond_to_batch" class="doc_header"><code>respond_to_batch</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L142" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>respond_to_batch</code>(<strong><code>model</code></strong>, <strong><code>queries</code></strong>, <strong><code>txt_len</code></strong>=<em><code>20</code></em>, <strong><code>top_k</code></strong>=<em><code>0</code></em>, <strong><code>top_p</code></strong>=<em><code>1.0</code></em>)</p>
|
||||
</blockquote>
|
||||
<p>Sample text from language model.</p>
|
||||
|
||||
|
@ -91,7 +91,7 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h2 id="AdaptiveKLController" class="doc_header"><code>class</code> <code>AdaptiveKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L26" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>AdaptiveKLController</code>(<strong><code>init_kl_coef</code></strong>, <strong><code>target</code></strong>, <strong><code>horizon</code></strong>)</p>
|
||||
<h2 id="AdaptiveKLController" class="doc_header"><code>class</code> <code>AdaptiveKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L29" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>AdaptiveKLController</code>(<strong><code>init_kl_coef</code></strong>, <strong><code>target</code></strong>, <strong><code>horizon</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Adaptive KL controller described in the paper:
|
||||
<a href="https://arxiv.org/pdf/1909.08593.pdf">https://arxiv.org/pdf/1909.08593.pdf</a></p>
|
||||
@ -140,7 +140,7 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L44" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>FixedKLController</code>(<strong><code>kl_coef</code></strong>)</p>
|
||||
<h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L47" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>FixedKLController</code>(<strong><code>kl_coef</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>Fixed KL controller.</p>
|
||||
|
||||
@ -182,13 +182,14 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
<span class="s2">"ppo_epochs"</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">ref_model</span><span class="p">,</span> <span class="o">**</span><span class="n">ppo_params</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">ref_model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="o">**</span><span class="n">ppo_params</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Initialize PPOTrainer.</span>
|
||||
<span class="sd"> </span>
|
||||
<span class="sd"> Args:</span>
|
||||
<span class="sd"> model (torch.model): Hugging Face transformer GPT2 model with value head</span>
|
||||
<span class="sd"> ref_model (torch.model): Hugging Face transformer GPT2 refrence model used for KL penalty</span>
|
||||
<span class="sd"> tokenizer (tokenizer): Hugging Face tokenizer</span>
|
||||
<span class="sd"> ppo_params (dict or None): PPO parameters for training. Can include following keys:</span>
|
||||
<span class="sd"> 'lr' (float): Adam learning rate, default: 1.41e-5</span>
|
||||
<span class="sd"> 'batch_size' (int): Number of samples per optimisation step, default: 256</span>
|
||||
@ -210,6 +211,9 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">ref_model</span> <span class="o">=</span> <span class="n">ref_model</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">tokenizer</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">data_collator</span> <span class="o">=</span> <span class="n">DataCollatorForLanguageModeling</span><span class="p">(</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">mlm</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">'lr'</span><span class="p">])</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">'adap_kl_ctrl'</span><span class="p">]:</span>
|
||||
@ -220,32 +224,33 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span> <span class="o">=</span> <span class="n">FixedKLController</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">'init_kl_coef'</span><span class="p">])</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">,</span> <span class="n">scores</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">queries</span><span class="p">,</span> <span class="n">responses</span><span class="p">,</span> <span class="n">scores</span><span class="p">):</span>
|
||||
<span class="sd">"""</span>
|
||||
<span class="sd"> Run a PPO optimisation step.</span>
|
||||
<span class="sd"> </span>
|
||||
<span class="sd"> args:</span>
|
||||
<span class="sd"> query (torch.tensor): tensor containing the encoded queries, shape [batch_size, query_length]</span>
|
||||
<span class="sd"> response (torch.tensor): tensor containing the encoded responses, shape [batch_size, response_length]</span>
|
||||
<span class="sd"> scores (torch.tensor): tensor containing the scores, shape [batch_size]</span>
|
||||
<span class="sd"> queries (List): List of tensors containing the encoded queries, shape [query_length]</span>
|
||||
<span class="sd"> responses (List): List of tensors containing the encoded responses, shape [response_length]</span>
|
||||
<span class="sd"> scores (List): tensor containing the scores, shape [batch_size]</span>
|
||||
<span class="sd"> </span>
|
||||
<span class="sd"> returns:</span>
|
||||
<span class="sd"> train_stats (dict): a summary of the training statistics</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="n">bs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">'batch_size'</span><span class="p">]</span>
|
||||
<span class="k">assert</span> <span class="n">bs</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">queries</span><span class="p">),</span> <span class="sa">f</span><span class="s2">"Batch size (</span><span class="si">{</span><span class="n">bs</span><span class="si">}</span><span class="s2">) does not match number of examples (</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">queries</span><span class="p">)</span><span class="si">}</span><span class="s2">)"</span>
|
||||
|
||||
<span class="n">timing</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
<span class="n">t0</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
|
||||
|
||||
<span class="n">gen_len</span> <span class="o">=</span> <span class="n">response</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="n">model_input</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="n">response_lengths</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">r</span><span class="p">)</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">responses</span><span class="p">]</span>
|
||||
|
||||
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
|
||||
<span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">,</span> <span class="n">values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batched_forward_pass</span><span class="p">(</span><span class="n">model_input</span><span class="p">,</span> <span class="n">gen_len</span><span class="p">)</span>
|
||||
<span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">,</span> <span class="n">values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batched_forward_pass</span><span class="p">(</span><span class="n">queries</span><span class="p">,</span> <span class="n">responses</span><span class="p">)</span>
|
||||
<span class="n">timing</span><span class="p">[</span><span class="s1">'time/ppo/forward_pass'</span><span class="p">]</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">t</span>
|
||||
|
||||
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
|
||||
<span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_reward</span><span class="p">,</span> <span class="n">kl_coef</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_rewards</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">)</span>
|
||||
<span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_reward</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_rewards</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">)</span>
|
||||
<span class="n">timing</span><span class="p">[</span><span class="s1">'time/ppo/compute_rewards'</span><span class="p">]</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">t</span>
|
||||
|
||||
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
|
||||
@ -255,9 +260,10 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">idxs</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">bs</span><span class="p">):</span>
|
||||
<span class="n">idx</span> <span class="o">=</span> <span class="n">idxs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||||
<span class="n">train_stats</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_minibatch</span><span class="p">(</span><span class="n">logprobs</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">values</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
|
||||
<span class="n">rewards</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">query</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
|
||||
<span class="n">response</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">model_input</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">])</span>
|
||||
<span class="n">train_stats</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_minibatch</span><span class="p">(</span><span class="n">logprobs</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">values</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||||
<span class="n">rewards</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">queries</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||||
<span class="n">responses</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">queries</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span><span class="n">responses</span><span class="p">[</span><span class="n">idx</span><span class="p">]])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
||||
<span class="n">all_stats</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">train_stats</span><span class="p">)</span>
|
||||
<span class="n">timing</span><span class="p">[</span><span class="s1">'time/ppo/optimize_step'</span><span class="p">]</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">t</span>
|
||||
|
||||
@ -266,11 +272,12 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
|
||||
<span class="c1"># reshape advantages/ratios such that they are not averaged.</span>
|
||||
<span class="n">train_stats</span><span class="p">[</span><span class="s1">'policy/advantages'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">train_stats</span><span class="p">[</span><span class="s1">'policy/advantages'</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">train_stats</span><span class="p">[</span><span class="s1">'policy/advantages'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nan_to_num</span><span class="p">(</span><span class="n">train_stats</span><span class="p">[</span><span class="s1">'policy/advantages'</span><span class="p">],</span> <span class="n">WANDB_PADDING</span><span class="p">)</span>
|
||||
<span class="n">train_stats</span><span class="p">[</span><span class="s1">'policy/ratio'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">train_stats</span><span class="p">[</span><span class="s1">'policy/ratio'</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
|
||||
<span class="n">stats</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">record_step_stats</span><span class="p">(</span><span class="n">scores</span><span class="o">=</span><span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="o">=</span><span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="o">=</span><span class="n">ref_logprobs</span><span class="p">,</span>
|
||||
<span class="n">non_score_reward</span><span class="o">=</span><span class="n">non_score_reward</span><span class="p">,</span> <span class="n">train_stats</span><span class="o">=</span><span class="n">train_stats</span><span class="p">,</span>
|
||||
<span class="n">kl_coef</span><span class="o">=</span><span class="n">kl_coef</span><span class="p">)</span>
|
||||
<span class="n">kl_coef</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span><span class="o">.</span><span class="n">value</span><span class="p">)</span>
|
||||
<span class="n">stats</span> <span class="o">=</span> <span class="n">stats_to_np</span><span class="p">(</span><span class="n">stats</span><span class="p">)</span>
|
||||
<span class="n">timing</span><span class="p">[</span><span class="s1">'time/ppo/calc_stats'</span><span class="p">]</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">t</span>
|
||||
|
||||
@ -280,24 +287,30 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
<span class="n">stats</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">timing</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">stats</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">batched_forward_pass</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_input</span><span class="p">,</span> <span class="n">gen_len</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">batched_forward_pass</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">queries</span><span class="p">,</span> <span class="n">responses</span><span class="p">):</span>
|
||||
<span class="sd">"""Calculate model outputs in multiple batches."""</span>
|
||||
<span class="n">bs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">'batch_size'</span><span class="p">]</span>
|
||||
<span class="n">fbs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">'forward_batch_size'</span><span class="p">]</span>
|
||||
<span class="n">logprobs</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">ref_logprobs</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">values</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">all_logprobs</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">all_ref_logprobs</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">all_values</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">'batch_size'</span><span class="p">]</span><span class="o">/</span><span class="n">fbs</span><span class="p">)):</span>
|
||||
<span class="n">m_input</span> <span class="o">=</span> <span class="n">model_input</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">fbs</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">fbs</span><span class="p">]</span>
|
||||
<span class="n">logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">m_input</span><span class="p">)</span>
|
||||
<span class="n">ref_logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ref_model</span><span class="p">(</span><span class="n">m_input</span><span class="p">)</span>
|
||||
|
||||
<span class="n">values</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">v</span><span class="p">[:,</span> <span class="o">-</span><span class="n">gen_len</span><span class="o">-</span><span class="mi">1</span><span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
|
||||
<span class="n">logprobs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">logprobs_from_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">[:,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:],</span> <span class="n">m_input</span><span class="p">[:,</span><span class="mi">1</span><span class="p">:])[:,</span> <span class="o">-</span><span class="n">gen_len</span><span class="p">:]</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
|
||||
<span class="n">ref_logprobs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">logprobs_from_logits</span><span class="p">(</span><span class="n">ref_logits</span><span class="p">[:,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:],</span> <span class="n">m_input</span><span class="p">[:,</span><span class="mi">1</span><span class="p">:])[:,</span> <span class="o">-</span><span class="n">gen_len</span><span class="p">:]</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">logprobs</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">ref_logprobs</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">values</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">bs</span><span class="o">/</span><span class="n">fbs</span><span class="p">)):</span>
|
||||
<span class="n">query_batch</span> <span class="o">=</span> <span class="n">queries</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">fbs</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">fbs</span><span class="p">]</span>
|
||||
<span class="n">response_batch</span> <span class="o">=</span> <span class="n">responses</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">fbs</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">fbs</span><span class="p">]</span>
|
||||
<span class="n">input_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_collator</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">q</span><span class="p">,</span> <span class="n">r</span><span class="p">])</span> <span class="k">for</span> <span class="n">q</span><span class="p">,</span> <span class="n">r</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">query_batch</span><span class="p">,</span> <span class="n">response_batch</span><span class="p">)])[</span><span class="s2">"input_ids"</span><span class="p">]</span>
|
||||
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
|
||||
<span class="n">logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
|
||||
<span class="n">ref_logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ref_model</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
|
||||
<span class="n">logprobs</span> <span class="o">=</span> <span class="n">logprobs_from_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">[:,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:],</span> <span class="n">input_ids</span><span class="p">[:,</span><span class="mi">1</span><span class="p">:])</span>
|
||||
<span class="n">ref_logprobs</span> <span class="o">=</span> <span class="n">logprobs_from_logits</span><span class="p">(</span><span class="n">ref_logits</span><span class="p">[:,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:],</span> <span class="n">input_ids</span><span class="p">[:,</span><span class="mi">1</span><span class="p">:])</span>
|
||||
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">fbs</span><span class="p">):</span>
|
||||
<span class="n">start</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">query_batch</span><span class="p">[</span><span class="n">j</span><span class="p">])</span><span class="o">-</span><span class="mi">1</span>
|
||||
<span class="n">end</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">query_batch</span><span class="p">[</span><span class="n">j</span><span class="p">])</span> <span class="o">+</span> <span class="nb">len</span><span class="p">(</span><span class="n">response_batch</span><span class="p">[</span><span class="n">j</span><span class="p">])</span><span class="o">-</span><span class="mi">1</span>
|
||||
<span class="n">all_values</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">v</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">start</span><span class="o">-</span><span class="mi">1</span><span class="p">:</span><span class="n">end</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||||
<span class="n">all_logprobs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">logprobs</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">])</span>
|
||||
<span class="n">all_ref_logprobs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ref_logprobs</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">])</span>
|
||||
<span class="k">return</span> <span class="n">all_logprobs</span><span class="p">,</span> <span class="n">all_ref_logprobs</span><span class="p">,</span> <span class="n">all_values</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">train_minibatch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">,</span> <span class="n">model_input</span><span class="p">):</span>
|
||||
<span class="sd">"""Train one PPO minibatch"""</span>
|
||||
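The new `batched_forward_pass` pads each query/response pair into one rectangular batch via the data collator and then slices the per-token statistics back out per example. A minimal sketch of the index arithmetic, with illustrative lengths (not part of the diff):

```python
# Illustrative only: a query of 3 tokens followed by a response of 4 tokens.
# logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) returns one log-prob per
# predicted token, so the log-prob of input_ids[k] sits at index k - 1.
query_len, response_len = 3, 4
start = query_len - 1                  # log-prob of the first response token
end = query_len + response_len - 1     # one past the log-prob of the last response token
assert end - start == response_len     # logprobs[j, start:end] covers exactly the response
# The value head output is sliced one position earlier, v[j, start-1:end-1],
# which likewise yields response_len entries per example.
```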
@ -310,18 +323,22 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
|
||||
<span class="k">def</span> <span class="nf">compute_rewards</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">):</span>
|
||||
<span class="sd">"""Compute per token rewards from scores and KL-penalty."""</span>
|
||||
<span class="n">kl</span> <span class="o">=</span> <span class="n">logprobs</span> <span class="o">-</span> <span class="n">ref_logprobs</span>
|
||||
<span class="n">non_score_reward</span> <span class="o">=</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span><span class="o">.</span><span class="n">value</span> <span class="o">*</span> <span class="n">kl</span>
|
||||
<span class="n">rewards</span> <span class="o">=</span> <span class="n">non_score_reward</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
|
||||
<span class="n">rewards</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+=</span> <span class="n">scores</span>
|
||||
<span class="k">return</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_reward</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span><span class="o">.</span><span class="n">value</span>
|
||||
<span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_rewards</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">score</span><span class="p">,</span> <span class="n">logprob</span><span class="p">,</span> <span class="n">ref_logprob</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">):</span>
|
||||
<span class="n">kl</span> <span class="o">=</span> <span class="n">logprob</span> <span class="o">-</span> <span class="n">ref_logprob</span>
|
||||
<span class="n">non_score_reward</span> <span class="o">=</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span><span class="o">.</span><span class="n">value</span> <span class="o">*</span> <span class="n">kl</span>
|
||||
<span class="n">non_score_rewards</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">non_score_reward</span><span class="p">)</span>
|
||||
<span class="n">reward</span> <span class="o">=</span> <span class="n">non_score_reward</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
||||
<span class="n">reward</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+=</span> <span class="n">score</span>
|
||||
<span class="n">rewards</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">reward</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_rewards</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">old_logprobs</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">,</span> <span class="n">model_input</span><span class="p">):</span>
|
||||
<span class="sd">"""Calculate policy and value losses."""</span>
|
||||
<span class="n">lastgaelam</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="n">advantages_reversed</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">gen_len</span> <span class="o">=</span> <span class="n">response</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
|
||||
|
||||
|
||||
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">gen_len</span><span class="p">)):</span>
|
||||
<span class="n">nextvalues</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="k">if</span> <span class="n">t</span> <span class="o"><</span> <span class="n">gen_len</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">else</span> <span class="mf">0.0</span>
|
||||
<span class="n">delta</span> <span class="o">=</span> <span class="n">rewards</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">'gamma'</span><span class="p">]</span> <span class="o">*</span> <span class="n">nextvalues</span> <span class="o">-</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
|
||||
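Because responses can now have different lengths, the KL penalty is computed per example instead of on a padded batch tensor. A self-contained sketch of the same logic (the `kl_coef` default is illustrative and stands in for `self.kl_ctl.value`):

```python
import torch

def compute_rewards_sketch(scores, logprobs, ref_logprobs, kl_coef=0.2):
    """scores: list of scalar tensors; logprobs/ref_logprobs: lists of [response_len] tensors."""
    rewards, non_score_rewards = [], []
    for score, logprob, ref_logprob in zip(scores, logprobs, ref_logprobs):
        kl = logprob - ref_logprob            # per-token KL estimate
        non_score_reward = -kl_coef * kl      # KL penalty on every generated token
        non_score_rewards.append(non_score_reward)
        reward = non_score_reward.clone()
        reward[-1] += score                   # the task score only touches the last token
        rewards.append(reward)
    return rewards, non_score_rewards

# Two responses of different lengths are handled side by side:
rewards, _ = compute_rewards_sketch([torch.tensor(1.0), torch.tensor(-0.5)],
                                    [torch.zeros(4), torch.zeros(7)],
                                    [torch.zeros(4), torch.zeros(7)])
```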
@ -379,13 +396,13 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
|
||||
<span class="k">def</span> <span class="nf">record_step_stats</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kl_coef</span><span class="p">,</span> <span class="o">**</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="sd">"""Record training step statistics."""</span>
|
||||
<span class="n">kl</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s1">'logprobs'</span><span class="p">]</span> <span class="o">-</span> <span class="n">data</span><span class="p">[</span><span class="s1">'ref_logprobs'</span><span class="p">]</span>
|
||||
<span class="n">mean_kl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">kl</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
|
||||
<span class="n">mean_entropy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="n">data</span><span class="p">[</span><span class="s1">'logprobs'</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
|
||||
<span class="n">mean_non_score_reward</span> <span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s1">'non_score_reward'</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
|
||||
<span class="n">kl_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">logprobs</span><span class="o">-</span><span class="n">ref_logprobs</span> <span class="k">for</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s1">'logprobs'</span><span class="p">],</span> <span class="n">data</span><span class="p">[</span><span class="s1">'ref_logprobs'</span><span class="p">])]</span>
|
||||
<span class="n">mean_kl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">kl</span><span class="p">)</span> <span class="k">for</span> <span class="n">kl</span> <span class="ow">in</span> <span class="n">kl_list</span><span class="p">]))</span>
|
||||
<span class="n">mean_entropy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="n">log_probs</span><span class="p">)</span> <span class="k">for</span> <span class="n">log_probs</span> <span class="ow">in</span> <span class="n">data</span><span class="p">[</span><span class="s1">'logprobs'</span><span class="p">]]))</span>
|
||||
<span class="n">mean_non_score_reward</span> <span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">non_score_reward</span><span class="p">)</span> <span class="k">for</span> <span class="n">non_score_reward</span> <span class="ow">in</span> <span class="n">data</span><span class="p">[</span><span class="s1">'non_score_reward'</span><span class="p">]]))</span>
|
||||
<span class="n">stats</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="s1">'objective/kl'</span><span class="p">:</span> <span class="n">mean_kl</span><span class="p">,</span>
|
||||
<span class="s1">'objective/kl_dist'</span><span class="p">:</span> <span class="n">kl</span><span class="p">,</span>
|
||||
<span class="s1">'objective/kl_dist'</span><span class="p">:</span> <span class="n">kl_list</span><span class="p">,</span>
|
||||
<span class="s1">'objective/logprobs'</span><span class="p">:</span> <span class="n">data</span><span class="p">[</span><span class="s1">'logprobs'</span><span class="p">],</span>
|
||||
<span class="s1">'objective/ref_logprobs'</span><span class="p">:</span> <span class="n">data</span><span class="p">[</span><span class="s1">'ref_logprobs'</span><span class="p">],</span>
|
||||
<span class="s1">'objective/kl_coef'</span><span class="p">:</span> <span class="n">kl_coef</span><span class="p">,</span>
|
||||
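Since the per-token statistics are now lists of tensors with varying lengths, the step statistics sum each example first and only then average across the batch. A tiny illustration with made-up numbers:

```python
import torch

kl_list = [torch.tensor([0.10, 0.20]), torch.tensor([0.05, 0.05, 0.10])]
mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list]))
# (0.30 + 0.20) / 2 = 0.25: each sequence contributes one summed KL value, regardless of length
```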
@ -417,7 +434,7 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
|
||||
|
||||
|
||||
<div class="output_markdown rendered_html output_subarea ">
|
||||
<h2 id="PPOTrainer" class="doc_header"><code>class</code> <code>PPOTrainer</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L54" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>PPOTrainer</code>(<strong><code>model</code></strong>, <strong><code>ref_model</code></strong>, <strong>**<code>ppo_params</code></strong>)</p>
|
||||
<h2 id="PPOTrainer" class="doc_header"><code>class</code> <code>PPOTrainer</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L57" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>PPOTrainer</code>(<strong><code>model</code></strong>, <strong><code>ref_model</code></strong>, <strong><code>tokenizer</code></strong>, <strong>**<code>ppo_params</code></strong>)</p>
|
||||
</blockquote>
|
||||
<p>The PPOTrainer uses Proximal Policy Optimization to optimise language models.</p>
|
||||
|
||||
|
@ -27,6 +27,13 @@ description: "Optimise GPT2 to produce IMDB movie reviews with controlled sentim
|
||||
</div>
|
||||
{% endraw %}
|
||||
|
||||
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
|
||||
<div class="text_cell_render border-box-sizing rendered_html">
|
||||
<p>{% include warning.html content='This notebook uses version <code>trl==0.0.3</code>.' %}</p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
|
||||
<div class="text_cell_render border-box-sizing rendered_html">
|
||||
<div style="text-align: center">
|
||||
|
Before Width: | Height: | Size: 861 KiB After Width: | Height: | Size: 737 KiB |
Before Width: | Height: | Size: 355 KiB After Width: | Height: | Size: 1.1 MiB |
@ -29,11 +29,11 @@ description: "Train transformer language models with reinforcement learning."
|
||||
|
||||
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
|
||||
<div class="text_cell_render border-box-sizing rendered_html">
|
||||
<h2 id="What-is-it?">What is it?<a class="anchor-link" href="#What-is-it?"> </a></h2><p>With <code>trl</code> you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the <code>transformer</code> library by 🤗 Hugging Face (<a href="https://github.com/huggingface/transformers">link</a>). Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GTP2 is implemented.</p>
|
||||
<h2 id="What-is-it?">What is it?<a class="anchor-link" href="#What-is-it?"> </a></h2><p>With <code>trl</code> you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the <code>transformer</code> library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via <code>transformers</code>. At this point only decoder architectures such as GTP2 are implemented.</p>
|
||||
<p><strong>Highlights:</strong></p>
|
||||
<ul>
|
||||
<li>GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.</li>
|
||||
<li>PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.</li>
|
||||
<li>GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.</li>
|
||||
<li>Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.</li>
|
||||
</ul>
|
||||
|
||||
@ -65,15 +65,19 @@ description: "Train transformer language models with reinforcement learning."
|
||||
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
|
||||
<div class="text_cell_render border-box-sizing rendered_html">
|
||||
<h3 id="Python-package">Python package<a class="anchor-link" href="#Python-package"> </a></h3><p>Install the library with pip:</p>
|
||||
<p><code>pip install trl</code></p>
|
||||
<h3 id="Repository">Repository<a class="anchor-link" href="#Repository"> </a></h3><p>If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:</p>
|
||||
<p><code>git clone https://github.com/lvwerra/trl.git</code></p>
|
||||
<p><code>cd tlr/</code></p>
|
||||
<p><code>pip install -r requirements.txt</code></p>
|
||||
<div class="highlight"><pre><span></span>pip install trl
|
||||
</pre></div>
|
||||
<h3 id="From-source">From source<a class="anchor-link" href="#From-source"> </a></h3><p>If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:</p>
|
||||
<div class="highlight"><pre><span></span>git clone https://github.com/lvwerra/trl.git
|
||||
<span class="nb">cd</span> tlr/
|
||||
pip install -r requirements.txt
|
||||
</pre></div>
|
||||
<h3 id="Jupyter-notebooks">Jupyter notebooks<a class="anchor-link" href="#Jupyter-notebooks"> </a></h3><p>If you run Jupyter notebooks you might need to run the following:</p>
|
||||
<p><code>jupyter nbextension enable --py --sys-prefix widgetsnbextension</code></p>
|
||||
<div class="highlight"><pre><span></span>jupyter nbextension <span class="nb">enable</span> --py --sys-prefix widgetsnbextension
|
||||
</pre></div>
|
||||
<p>For JupyterLab you additionally need to run this command:</p>
|
||||
<p><code>jupyter labextension install @jupyter-widgets/jupyterlab-manager</code></p>
|
||||
<div class="highlight"><pre><span></span>jupyter labextension install @jupyter-widgets/jupyterlab-manager
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
@ -111,7 +115,7 @@ description: "Train transformer language models with reinforcement learning."
|
||||
|
||||
<span class="c1"># initialize trainer</span>
|
||||
<span class="n">ppo_config</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">'forward_batch_size'</span><span class="p">:</span> <span class="mi">1</span><span class="p">}</span>
|
||||
<span class="n">ppo_trainer</span> <span class="o">=</span> <span class="n">PPOTrainer</span><span class="p">(</span><span class="n">gpt2_model</span><span class="p">,</span> <span class="n">gpt2_model_ref</span><span class="p">,</span> <span class="o">**</span><span class="n">ppo_config</span><span class="p">)</span>
|
||||
<span class="n">ppo_trainer</span> <span class="o">=</span> <span class="n">PPOTrainer</span><span class="p">(</span><span class="n">gpt2_model</span><span class="p">,</span> <span class="n">gpt2_model_ref</span><span class="p">,</span> <span class="n">gpt2_tokenizer</span><span class="p">,</span> <span class="o">**</span><span class="n">ppo_config</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># encode a query</span>
|
||||
<span class="n">query_txt</span> <span class="o">=</span> <span class="s2">"This morning I went to the "</span>
|
||||
@ -123,10 +127,10 @@ description: "Train transformer language models with reinforcement learning."
|
||||
|
||||
<span class="c1"># define a reward for response</span>
|
||||
<span class="c1"># (this could be any reward such as human feedback or output from another model)</span>
|
||||
<span class="n">reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.0</span><span class="p">])</span>
|
||||
<span class="n">reward</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)]</span>
|
||||
|
||||
<span class="c1"># train model with ppo</span>
|
||||
<span class="n">train_stats</span> <span class="o">=</span> <span class="n">ppo_trainer</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">query_tensor</span><span class="p">,</span> <span class="n">response_tensor</span><span class="p">,</span> <span class="n">reward</span><span class="p">)</span>
|
||||
<span class="n">train_stats</span> <span class="o">=</span> <span class="n">ppo_trainer</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="n">query_tensor</span><span class="p">[</span><span class="mi">0</span><span class="p">]],</span> <span class="p">[</span><span class="n">response_tensor</span><span class="p">[</span><span class="mi">0</span><span class="p">]],</span> <span class="n">reward</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
|
||||
</div>
|
||||
@ -139,9 +143,9 @@ description: "Train transformer language models with reinforcement learning."
|
||||
<div class="output_area">
|
||||
|
||||
<div class="output_subarea output_stream output_stderr output_text">
|
||||
<pre>Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.9.attn.masked_bias', 'h.8.attn.masked_bias', 'lm_head.weight', 'h.4.attn.masked_bias', 'v_head.summary.weight', 'h.10.attn.masked_bias', 'h.0.attn.masked_bias', 'h.5.attn.masked_bias', 'h.7.attn.masked_bias', 'h.6.attn.masked_bias', 'h.1.attn.masked_bias', 'v_head.summary.bias', 'h.11.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias']
|
||||
<pre>Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['v_head.summary.weight', 'h.8.attn.masked_bias', 'h.0.attn.masked_bias', 'h.9.attn.masked_bias', 'h.1.attn.masked_bias', 'h.6.attn.masked_bias', 'h.5.attn.masked_bias', 'h.3.attn.masked_bias', 'lm_head.weight', 'v_head.summary.bias', 'h.2.attn.masked_bias', 'h.11.attn.masked_bias', 'h.7.attn.masked_bias', 'h.4.attn.masked_bias', 'h.10.attn.masked_bias']
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.9.attn.masked_bias', 'h.8.attn.masked_bias', 'lm_head.weight', 'h.4.attn.masked_bias', 'v_head.summary.weight', 'h.10.attn.masked_bias', 'h.0.attn.masked_bias', 'h.5.attn.masked_bias', 'h.7.attn.masked_bias', 'h.6.attn.masked_bias', 'h.1.attn.masked_bias', 'v_head.summary.bias', 'h.11.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias']
|
||||
Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['v_head.summary.weight', 'h.8.attn.masked_bias', 'h.0.attn.masked_bias', 'h.9.attn.masked_bias', 'h.1.attn.masked_bias', 'h.6.attn.masked_bias', 'h.5.attn.masked_bias', 'h.3.attn.masked_bias', 'lm_head.weight', 'v_head.summary.bias', 'h.2.attn.masked_bias', 'h.11.attn.masked_bias', 'h.7.attn.masked_bias', 'h.4.attn.masked_bias', 'h.10.attn.masked_bias']
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
</pre>
|
||||
</div>
|
||||
@ -155,7 +159,7 @@ You should probably TRAIN this model on a down-stream task to be able to use it
|
||||
|
||||
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
|
||||
<div class="text_cell_render border-box-sizing rendered_html">
|
||||
<h3 id="Advanced-example:-IMDB-sentiment">Advanced example: IMDB sentiment<a class="anchor-link" href="#Advanced-example:-IMDB-sentiment"> </a></h3><p>For a detailed example check out the notebook <em>Tune GPT2 to generate positive reviews</em>, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:</p>
|
||||
<h3 id="Advanced-example:-IMDB-sentiment">Advanced example: IMDB sentiment<a class="anchor-link" href="#Advanced-example:-IMDB-sentiment"> </a></h3><p>For a detailed example check out the notebook <code>04-gpt2-sentiment-ppo-training.ipynb</code>, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:</p>
|
||||
<div style="text-align: center">
|
||||
{% include image.html max-width="800" file="/trl/images/table_imdb_preview.png" %}
|
||||
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
|
||||
@ -171,8 +175,11 @@ You should probably TRAIN this model on a down-stream task to be able to use it
|
||||
<li><code>00-core.ipynb</code>: Contains the utility functions used throughout the library and examples.</li>
|
||||
<li><code>01-gpt2-with-value-head.ipynb</code>: Implementation of a <code>transformer</code> compatible GPT2 model with an additional value head as well as a function to generate sequences.</li>
|
||||
<li><code>02-ppo.ipynb</code>: Implementation of the PPOTrainer used to train language models.</li>
|
||||
<li><code>03-bert-imdb-training.ipynb</code>: Training of BERT with <code>simpletransformers</code> to classify sentiment on the IMDB dataset.</li>
|
||||
<li><code>03-bert-imdb-training.ipynb</code>: Training of DistilBERT to classify sentiment on the IMDB dataset.</li>
|
||||
<li><code>04-gpt2-sentiment-ppo-training.ipynb</code>: Fine-tune GPT2 with the BERT sentiment classifier to produce positive movie reviews.</li>
|
||||
</ul>
|
||||
<p>Currently using <code>trl==0.0.3</code>:</p>
|
||||
<ul>
|
||||
<li><code>05-gpt2-sentiment-control.ipynb</code>: Fine-tune GPT2 with the BERT sentiment classifier to produce movie reviews with controlled sentiment.</li>
|
||||
</ul>
|
||||
|
||||
@ -182,7 +189,7 @@ You should probably TRAIN this model on a down-stream task to be able to use it
|
||||
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
|
||||
<div class="text_cell_render border-box-sizing rendered_html">
|
||||
<h2 id="References">References<a class="anchor-link" href="#References"> </a></h2><h3 id="Proximal-Policy-Optimisation">Proximal Policy Optimisation<a class="anchor-link" href="#Proximal-Policy-Optimisation"> </a></h3><p>The PPO implementation largely follows the structure introduced in the paper <strong>"Fine-Tuning Language Models from Human Preferences"</strong> by D. Ziegler et al. [<a href="https://arxiv.org/pdf/1909.08593.pdf">paper</a>, <a href="https://github.com/openai/lm-human-preferences">code</a>].</p>
|
||||
<h3 id="Language-models">Language models<a class="anchor-link" href="#Language-models"> </a></h3><p>The language models utilize the <code>transformer</code> library by 🤗Hugging Face.</p>
|
||||
<h3 id="Language-models">Language models<a class="anchor-link" href="#Language-models"> </a></h3><p>The language models utilize the <code>transformers</code> library by 🤗 Hugging Face.</p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
@ -26,10 +26,29 @@
|
||||
"# export\n",
|
||||
"import torch\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from torch.nn.utils.rnn import pad_sequence\n",
|
||||
"\n",
|
||||
"import collections\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Constants"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# export\n",
|
||||
"WANDB_PADDING = -1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@ -64,7 +83,7 @@
|
||||
" results = dict()\n",
|
||||
" for k in stats_dicts[0]:\n",
|
||||
" stats_list = [torch.flatten(d[k]) for d in stats_dicts]\n",
|
||||
" results[k] = torch.stack(stats_list)\n",
|
||||
" results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)\n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
"def add_suffix(input_dict, suffix):\n",
|
||||
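`stack_dicts` can no longer simply `torch.stack` the per-example statistics because they differ in length, so they are right-padded into a rectangle instead. A short sketch of the padding behaviour using the `WANDB_PADDING = -1` sentinel defined above (the example values are illustrative):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

WANDB_PADDING = -1
stats_list = [torch.tensor([0.1, 0.2, 0.3]), torch.tensor([0.4, 0.5])]
padded = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
# tensor([[ 0.1000,  0.2000,  0.3000],
#         [ 0.4000,  0.5000, -1.0000]])
# torch.nan_to_num(padded, WANDB_PADDING) maps any remaining NaNs to the same sentinel,
# keeping the logged tensors finite for Weights & Biases.
```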
@ -144,7 +163,11 @@
|
||||
" new_dict[k] = v\n",
|
||||
" if np.isscalar(new_dict[k]):\n",
|
||||
" new_dict[k] = float(new_dict[k])\n",
|
||||
" return new_dict\n"
|
||||
" return new_dict\n",
|
||||
"\n",
|
||||
"def listify_batch(tensor):\n",
|
||||
" \"\"\"Turns the first dimension of a tensor into a list.\"\"\"\n",
|
||||
" return [tensor[i] for i in range(tensor.shape[0])]"
|
||||
]
|
||||
},
|
||||
{
|
||||
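For callers that still build rectangular `[batch, seq_len]` tensors, the new `listify_batch` helper bridges to the list-of-tensors API. A minimal usage sketch:

```python
import torch

def listify_batch(tensor):
    """Turns the first dimension of a tensor into a list."""
    return [tensor[i] for i in range(tensor.shape[0])]

queries = torch.randint(0, 100, (4, 8))   # old-style batch tensor
query_list = listify_batch(queries)       # list of 4 tensors, each of shape [8]
```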
@ -197,7 +220,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
|
@ -43,10 +43,31 @@
|
||||
"\n",
|
||||
"from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel\n",
|
||||
"from transformers import top_k_top_p_filtering\n",
|
||||
"from transformers.modeling_outputs import ModelOutput\n",
|
||||
"from torch import nn\n",
|
||||
"from torch.nn import Identity\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torch"
|
||||
"import torch\n",
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Optional, Tuple"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# exports\n",
|
||||
"@dataclass\n",
|
||||
"class CausalLMOutputWithCrossAttentions(ModelOutput):\n",
|
||||
" loss: Optional[torch.FloatTensor] = None\n",
|
||||
" logits: torch.FloatTensor = None\n",
|
||||
" past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n",
|
||||
" hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n",
|
||||
" attentions: Optional[Tuple[torch.FloatTensor]] = None\n",
|
||||
" cross_attentions: Optional[Tuple[torch.FloatTensor]] = None\n",
|
||||
" value: Optional[torch.FloatTensor] = None"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -138,8 +159,11 @@
|
||||
" mc_token_ids=None,\n",
|
||||
" lm_labels=None,\n",
|
||||
" mc_labels=None,\n",
|
||||
" return_dict=False,\n",
|
||||
" output_attentions=False,\n",
|
||||
" output_hidden_states=False,\n",
|
||||
" ):\n",
|
||||
" \n",
|
||||
" loss=None\n",
|
||||
" transformer_outputs = self.transformer(\n",
|
||||
" input_ids,\n",
|
||||
" past_key_values=past_key_values,\n",
|
||||
@ -155,8 +179,20 @@
|
||||
" lm_logits = self.lm_head(hidden_states)\n",
|
||||
" value = self.v_head(hidden_states).squeeze(-1)\n",
|
||||
"\n",
|
||||
" outputs = (lm_logits,) + transformer_outputs[1:] + (value,)\n",
|
||||
" \n",
|
||||
" if not return_dict:\n",
|
||||
" outputs = (lm_logits,) + transformer_outputs[1:] + (value,)\n",
|
||||
" return outputs\n",
|
||||
"\n",
|
||||
" return CausalLMOutputWithCrossAttentions(\n",
|
||||
" loss=loss,\n",
|
||||
" logits=lm_logits,\n",
|
||||
" past_key_values=transformer_outputs.past_key_values,\n",
|
||||
" hidden_states=transformer_outputs.hidden_states,\n",
|
||||
" attentions=transformer_outputs.attentions,\n",
|
||||
" cross_attentions=transformer_outputs.cross_attentions,\n",
|
||||
" value=value,\n",
|
||||
" ) \n",
|
||||
" return outputs"
|
||||
]
|
||||
},
|
||||
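With the new `return_dict` flag the value-head model can return either the tuple that the PPO code unpacks or a `ModelOutput`-style dataclass. A short sketch of both call styles; the model and tokenizer are loaded as in the cells of this notebook:

```python
from transformers import GPT2Tokenizer

model = GPT2HeadWithValueModel.from_pretrained('gpt2')   # defined in this notebook
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = tokenizer("My movie review:", return_tensors="pt")["input_ids"]

# Tuple output (return_dict=False is the default): logits, past_key_values, value.
logits, _, value = model(input_ids)

# Dataclass output: fields mirror CausalLMOutputWithCrossAttentions above.
out = model(input_ids, return_dict=True)
logits, value = out.logits, out.value
```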
@ -172,7 +208,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.6.attn.masked_bias', 'h.10.attn.masked_bias', 'h.0.attn.masked_bias', 'h.3.attn.masked_bias', 'h.7.attn.masked_bias', 'h.5.attn.masked_bias', 'h.11.attn.masked_bias', 'h.9.attn.masked_bias', 'h.8.attn.masked_bias', 'lm_head.weight', 'h.4.attn.masked_bias', 'v_head.summary.weight', 'h.2.attn.masked_bias', 'v_head.summary.bias', 'h.1.attn.masked_bias']\n",
|
||||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = GPT2HeadWithValueModel.from_pretrained('gpt2')\n",
|
||||
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2')"
|
||||
@ -464,7 +509,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
|
114
nbs/02-ppo.ipynb
@ -40,15 +40,18 @@
|
||||
"import time\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"from transformers import DataCollatorForLanguageModeling\n",
|
||||
"\n",
|
||||
"from trl.core import (logprobs_from_logits,\n",
|
||||
" whiten,\n",
|
||||
" clip_by_value,\n",
|
||||
" entropy_from_logits,\n",
|
||||
" flatten_dict,\n",
|
||||
" average_torch_dicts,\n",
|
||||
" stats_to_np,\n",
|
||||
" stack_dicts,\n",
|
||||
" add_suffix)"
|
||||
" whiten,\n",
|
||||
" clip_by_value,\n",
|
||||
" entropy_from_logits,\n",
|
||||
" flatten_dict,\n",
|
||||
" average_torch_dicts,\n",
|
||||
" stats_to_np,\n",
|
||||
" stack_dicts,\n",
|
||||
" add_suffix,\n",
|
||||
" WANDB_PADDING)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -132,13 +135,14 @@
|
||||
" \"ppo_epochs\": 4, \n",
|
||||
" } \n",
|
||||
" \n",
|
||||
" def __init__(self, model, ref_model, **ppo_params):\n",
|
||||
" def __init__(self, model, ref_model, tokenizer, **ppo_params):\n",
|
||||
" \"\"\"\n",
|
||||
" Initialize PPOTrainer.\n",
|
||||
" \n",
|
||||
" Args:\n",
|
||||
" model (torch.model): Hugging Face transformer GPT2 model with value head\n",
|
||||
" ref_model (torch.model): Hugging Face transformer GPT2 refrence model used for KL penalty\n",
|
||||
" tokenizer (tokenizer): Hugging Face tokenizer\n",
|
||||
" ppo_params (dict or None): PPO parameters for training. Can include following keys:\n",
|
||||
" 'lr' (float): Adam learning rate, default: 1.41e-5\n",
|
||||
" 'batch_size' (int): Number of samples per optimisation step, default: 256\n",
|
||||
@ -160,6 +164,9 @@
|
||||
" \n",
|
||||
" self.ref_model = ref_model\n",
|
||||
" self.model = model\n",
|
||||
" self.tokenizer = tokenizer\n",
|
||||
" self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
|
||||
" \n",
|
||||
" self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr'])\n",
|
||||
" \n",
|
||||
" if self.ppo_params['adap_kl_ctrl']:\n",
|
||||
@ -170,32 +177,33 @@
|
||||
" self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef'])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def step(self, query, response, scores):\n",
|
||||
" def step(self, queries, responses, scores):\n",
|
||||
" \"\"\"\n",
|
||||
" Run a PPO optimisation step.\n",
|
||||
" \n",
|
||||
" args:\n",
|
||||
" query (torch.tensor): tensor containing the encoded queries, shape [batch_size, query_length]\n",
|
||||
" response (torch.tensor): tensor containing the encoded responses, shape [batch_size, response_length]\n",
|
||||
" scores (torch.tensor): tensor containing the scores, shape [batch_size]\n",
|
||||
" queries (List): List of tensors containing the encoded queries, shape [query_length]\n",
|
||||
" responses (List): List of tensors containing the encoded responses, shape [response_length]\n",
|
||||
" scores (List): tensor containing the scores, shape [batch_size]\n",
|
||||
" \n",
|
||||
" returns:\n",
|
||||
" train_stats (dict): a summary of the training statistics\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" bs = self.ppo_params['batch_size']\n",
|
||||
" assert bs == len(queries), f\"Batch size ({bs}) does not match number of examples ({len(queries)})\"\n",
|
||||
" \n",
|
||||
" timing = dict()\n",
|
||||
" t0 = time.time()\n",
|
||||
" \n",
|
||||
" gen_len = response.shape[1]\n",
|
||||
" model_input = torch.cat((query, response), axis=1)\n",
|
||||
" response_lengths = [len(r) for r in responses]\n",
|
||||
" \n",
|
||||
" t = time.time()\n",
|
||||
" logprobs, ref_logprobs, values = self.batched_forward_pass(model_input, gen_len)\n",
|
||||
" logprobs, ref_logprobs, values = self.batched_forward_pass(queries, responses)\n",
|
||||
" timing['time/ppo/forward_pass'] = time.time()-t\n",
|
||||
"\n",
|
||||
" t = time.time()\n",
|
||||
" rewards, non_score_reward, kl_coef = self.compute_rewards(scores, logprobs, ref_logprobs)\n",
|
||||
" rewards, non_score_reward = self.compute_rewards(scores, logprobs, ref_logprobs)\n",
|
||||
" timing['time/ppo/compute_rewards'] = time.time()-t \n",
|
||||
" \n",
|
||||
" t = time.time() \n",
|
||||
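The reworked `step` signature is what enables dynamic input sizes: each query/response pair can have its own length, as long as the number of examples matches `batch_size`. A usage sketch, assuming the models and tokenizer are loaded as in the README example (token ids are illustrative):

```python
import torch

ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer,
                         batch_size=2, forward_batch_size=1)

queries   = [torch.tensor([31373, 11, 995]), torch.tensor([31373, 995])]    # lengths 3 and 2
responses = [torch.tensor([318, 257, 1332, 13]), torch.tensor([318, 257])]  # lengths 4 and 2
rewards   = [torch.tensor(1.0), torch.tensor(-0.3)]                         # one scalar reward per example
train_stats = ppo_trainer.step(queries, responses, rewards)
```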
@ -205,9 +213,10 @@
|
||||
" random.shuffle(idxs)\n",
|
||||
" for i in range(bs):\n",
|
||||
" idx = idxs[i]\n",
|
||||
" train_stats = self.train_minibatch(logprobs[idx:idx+1], values[idx:idx+1],\n",
|
||||
" rewards[idx:idx+1], query[idx:idx+1],\n",
|
||||
" response[idx:idx+1], model_input[idx:idx+1])\n",
|
||||
" train_stats = self.train_minibatch(logprobs[idx].unsqueeze(0), values[idx].unsqueeze(0),\n",
|
||||
" rewards[idx].unsqueeze(0), queries[idx].unsqueeze(0),\n",
|
||||
" responses[idx].unsqueeze(0),\n",
|
||||
" torch.cat([queries[idx],responses[idx]]).unsqueeze(0))\n",
|
||||
" all_stats.append(train_stats)\n",
|
||||
" timing['time/ppo/optimize_step'] = time.time()-t\n",
|
||||
" \n",
|
||||
@ -216,11 +225,12 @@
|
||||
" \n",
|
||||
" # reshape advantages/ratios such that they are not averaged.\n",
|
||||
" train_stats['policy/advantages'] = torch.flatten(train_stats['policy/advantages']).unsqueeze(0)\n",
|
||||
" train_stats['policy/advantages'] = torch.nan_to_num(train_stats['policy/advantages'], WANDB_PADDING)\n",
|
||||
" train_stats['policy/ratio'] = torch.flatten(train_stats['policy/ratio']).unsqueeze(0)\n",
|
||||
" \n",
|
||||
" stats = self.record_step_stats(scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs,\n",
|
||||
" non_score_reward=non_score_reward, train_stats=train_stats,\n",
|
||||
" kl_coef=kl_coef)\n",
|
||||
" kl_coef=self.kl_ctl.value)\n",
|
||||
" stats = stats_to_np(stats)\n",
|
||||
" timing['time/ppo/calc_stats'] = time.time()-t\n",
|
||||
"\n",
|
||||
@ -230,24 +240,30 @@
|
||||
" stats.update(timing)\n",
|
||||
" return stats\n",
|
||||
"\n",
|
||||
" def batched_forward_pass(self, model_input, gen_len):\n",
|
||||
" def batched_forward_pass(self, queries, responses):\n",
|
||||
" \"\"\"Calculate model outputs in multiple batches.\"\"\"\n",
|
||||
" bs = self.ppo_params['batch_size']\n",
|
||||
" fbs = self.ppo_params['forward_batch_size']\n",
|
||||
" logprobs = []\n",
|
||||
" ref_logprobs = []\n",
|
||||
" values = []\n",
|
||||
" all_logprobs = []\n",
|
||||
" all_ref_logprobs = []\n",
|
||||
" all_values = []\n",
|
||||
" \n",
|
||||
" for i in range(int(self.ppo_params['batch_size']/fbs)):\n",
|
||||
" m_input = model_input[i*fbs:(i+1)*fbs]\n",
|
||||
" logits, _, v = self.model(m_input)\n",
|
||||
" ref_logits, _, _ = self.ref_model(m_input)\n",
|
||||
" \n",
|
||||
" values.append(v[:, -gen_len-1:-1].detach())\n",
|
||||
" logprobs.append(logprobs_from_logits(logits[:,:-1,:], m_input[:,1:])[:, -gen_len:].detach())\n",
|
||||
" ref_logprobs.append(logprobs_from_logits(ref_logits[:,:-1,:], m_input[:,1:])[:, -gen_len:].detach())\n",
|
||||
" \n",
|
||||
" return torch.cat(logprobs), torch.cat(ref_logprobs), torch.cat(values)\n",
|
||||
" for i in range(int(bs/fbs)):\n",
|
||||
" query_batch = queries[i*fbs:(i+1)*fbs]\n",
|
||||
" response_batch = responses[i*fbs:(i+1)*fbs]\n",
|
||||
" input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])[\"input_ids\"]\n",
|
||||
" with torch.no_grad():\n",
|
||||
" logits, _, v = self.model(input_ids)\n",
|
||||
" ref_logits, _, _ = self.ref_model(input_ids)\n",
|
||||
" logprobs = logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:])\n",
|
||||
" ref_logprobs = logprobs_from_logits(ref_logits[:,:-1,:], input_ids[:,1:])\n",
|
||||
" for j in range(fbs):\n",
|
||||
" start = len(query_batch[j])-1\n",
|
||||
" end = len(query_batch[j]) + len(response_batch[j])-1\n",
|
||||
" all_values.append(v[j, start-1:end-1])\n",
|
||||
" all_logprobs.append(logprobs[j, start:end])\n",
|
||||
" all_ref_logprobs.append(ref_logprobs[j, start:end])\n",
|
||||
" return all_logprobs, all_ref_logprobs, all_values\n",
|
||||
" \n",
|
||||
" def train_minibatch(self, logprobs, values, rewards, query, response, model_input):\n",
|
||||
" \"\"\"Train one PPO minibatch\"\"\"\n",
|
||||
@ -260,18 +276,22 @@
|
||||
" \n",
|
||||
" def compute_rewards(self, scores, logprobs, ref_logprobs):\n",
|
||||
" \"\"\"Compute per token rewards from scores and KL-penalty.\"\"\"\n",
|
||||
" kl = logprobs - ref_logprobs\n",
|
||||
" non_score_reward = -self.kl_ctl.value * kl\n",
|
||||
" rewards = non_score_reward.clone().detach()\n",
|
||||
" rewards[:, -1] += scores\n",
|
||||
" return rewards, non_score_reward, self.kl_ctl.value\n",
|
||||
" rewards, non_score_rewards = [], []\n",
|
||||
" for score, logprob, ref_logprob in zip(scores, logprobs, ref_logprobs):\n",
|
||||
" kl = logprob - ref_logprob\n",
|
||||
" non_score_reward = -self.kl_ctl.value * kl\n",
|
||||
" non_score_rewards.append(non_score_reward)\n",
|
||||
" reward = non_score_reward.clone()\n",
|
||||
" reward[-1] += score\n",
|
||||
" rewards.append(reward)\n",
|
||||
" return rewards, non_score_rewards\n",
|
||||
"\n",
|
||||
" def loss(self, old_logprobs, values, rewards, query, response, model_input):\n",
|
||||
" \"\"\"Calculate policy and value losses.\"\"\"\n",
|
||||
" lastgaelam = 0\n",
|
||||
" advantages_reversed = []\n",
|
||||
" gen_len = response.shape[1]\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" for t in reversed(range(gen_len)):\n",
|
||||
" nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0\n",
|
||||
" delta = rewards[:, t] + self.ppo_params['gamma'] * nextvalues - values[:, t]\n",
|
||||
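`loss` still estimates advantages with the standard GAE recursion over the response tokens; only its inputs now arrive as per-example slices. A worked sketch of that recursion with made-up numbers (the `lastgaelam` update is the standard GAE step, written out here for completeness rather than copied from the diff):

```python
import torch

gamma, lam = 1.0, 0.95                       # illustrative hyperparameters
rewards = torch.tensor([[0.0, 0.0, 1.0]])    # [batch=1, gen_len=3]
values  = torch.tensor([[0.2, 0.3, 0.4]])

lastgaelam = 0
advantages_reversed = []
gen_len = rewards.shape[1]
for t in reversed(range(gen_len)):
    nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam   # standard GAE accumulator
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)  # shape [1, gen_len]
```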
@ -329,13 +349,13 @@
|
||||
"\n",
|
||||
" def record_step_stats(self, kl_coef, **data):\n",
|
||||
" \"\"\"Record training step statistics.\"\"\"\n",
|
||||
" kl = data['logprobs'] - data['ref_logprobs']\n",
|
||||
" mean_kl = torch.mean(torch.sum(kl, axis=-1))\n",
|
||||
" mean_entropy = torch.mean(torch.sum(-data['logprobs'], axis=1))\n",
|
||||
" mean_non_score_reward =torch.mean(torch.sum(data['non_score_reward'], axis=1))\n",
|
||||
" kl_list = [logprobs-ref_logprobs for logprobs, ref_logprobs in zip(data['logprobs'], data['ref_logprobs'])] \n",
|
||||
" mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list]))\n",
|
||||
" mean_entropy = torch.mean(torch.stack([torch.sum(-log_probs) for log_probs in data['logprobs']]))\n",
|
||||
" mean_non_score_reward =torch.mean(torch.stack([torch.sum(non_score_reward) for non_score_reward in data['non_score_reward']]))\n",
|
||||
" stats = {\n",
|
||||
" 'objective/kl': mean_kl,\n",
|
||||
" 'objective/kl_dist': kl,\n",
|
||||
" 'objective/kl_dist': kl_list,\n",
|
||||
" 'objective/logprobs': data['logprobs'],\n",
|
||||
" 'objective/ref_logprobs': data['ref_logprobs'],\n",
|
||||
" 'objective/kl_coef': kl_coef,\n",
|
||||
@ -396,7 +416,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
|
@ -8,6 +8,13 @@
|
||||
"> Optimise GPT2 to produce IMDB movie reviews with controlled sentiment using a BERT sentiment classifier for rewards."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"> warning: This notebook uses version `trl==0.0.3`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@ -1644,18 +1651,6 @@
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Before Width: | Height: | Size: 861 KiB After Width: | Height: | Size: 737 KiB |
Before Width: | Height: | Size: 355 KiB After Width: | Height: | Size: 1.1 MiB |
@ -14,11 +14,11 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## What is it?\n",
|
||||
"With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the `transformer` library by 🤗 Hugging Face ([link](https://github.com/huggingface/transformers)). Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GTP2 is implemented.\n",
|
||||
"With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the [`transformer`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point only decoder architectures such as GTP2 are implemented.\n",
|
||||
"\n",
|
||||
"**Highlights:**\n",
|
||||
"- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.\n",
|
||||
"- PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.\n",
|
||||
"- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.\n",
|
||||
"- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier."
|
||||
]
|
||||
},
|
||||
@ -55,27 +55,29 @@
 "source": [
 "### Python package\n",
 "Install the library with pip:\n",
 "```bash\n",
 "pip install trl\n",
 "```\n",
 "\n",
 "`pip install trl`\n",
 "\n",
 "### Repository\n",
 "### From source\n",
 "If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:\n",
 "\n",
 "`git clone https://github.com/lvwerra/trl.git`\n",
 "\n",
 "`cd tlr/`\n",
 "\n",
 "`pip install -r requirements.txt`\n",
 "\n",
 "```bash\n",
 "git clone https://github.com/lvwerra/trl.git\n",
 "cd trl/\n",
 "pip install -r requirements.txt\n",
 "```\n",
 "### Jupyter notebooks\n",
 "\n",
 "If you run Jupyter notebooks you might need to run the following:\n",
 "\n",
 "`jupyter nbextension enable --py --sys-prefix widgetsnbextension`\n",
 "```bash\n",
 "jupyter nbextension enable --py --sys-prefix widgetsnbextension\n",
 "```\n",
 "\n",
 "For Jupyterlab additionally this command:\n",
 "\n",
 "`jupyter labextension install @jupyter-widgets/jupyterlab-manager`"
 "```bash\n",
 "jupyter labextension install @jupyter-widgets/jupyterlab-manager\n",
 "```"
 ]
 },
 {
@ -112,7 +114,7 @@
 "\n",
 "# initialize trainer\n",
 "ppo_config = {'batch_size': 1, 'forward_batch_size': 1}\n",
 "ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, **ppo_config)\n",
 "ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)\n",
 "\n",
 "# encode a query\n",
 "query_txt = \"This morning I went to the \"\n",
@ -124,10 +126,10 @@
 "\n",
 "# define a reward for response\n",
 "# (this could be any reward such as human feedback or output from another model)\n",
 "reward = torch.tensor([1.0]) \n",
 "reward = [torch.tensor(1.0)]\n",
 "\n",
 "# train model with ppo\n",
 "train_stats = ppo_trainer.step(query_tensor, response_tensor, reward)"
 "train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)"
 ]
 },
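The point of the change above is that `step` now takes plain Python lists of tensors, so the queries and responses within one batch no longer have to share a length. A minimal sketch of a larger, variable-length batch (token ids and rewards are made up for illustration; `gpt2_model`, `gpt2_model_ref` and `gpt2_tokenizer` are assumed to be loaded as in the quick-start above):

```python
import torch
from trl.ppo import PPOTrainer

# hypothetical batch of three examples with different query/response lengths
queries   = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6, 7, 8, 9])]
responses = [torch.tensor([10, 11]), torch.tensor([12, 13, 14]), torch.tensor([15])]
rewards   = [torch.tensor(1.0), torch.tensor(0.2), torch.tensor(-0.5)]  # one scalar per example

ppo_config = {'batch_size': 3, 'forward_batch_size': 1}
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)
train_stats = ppo_trainer.step(queries, responses, rewards)
```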
 {
@ -135,7 +137,7 @@
 "metadata": {},
 "source": [
 "### Advanced example: IMDB sentiment\n",
 "For a detailed example check out the notebook *Tune GPT2 to generate positive reviews*, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:\n",
 "For a detailed example check out the notebook `04-gpt2-sentiment-ppo-training.ipynb`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:\n",
 "\n",
 "<div style=\"text-align: center\">\n",
 "<img src='images/table_imdb_preview.png' width='800'>\n",
@ -154,8 +156,10 @@
 "- `00-core.ipynb`: Contains the utility functions used throughout the library and examples.\n",
 "- `01-gpt2-with-value-head.ipynb`: Implementation of a `transformer` compatible GPT2 model with an additional value head as well as a function to generate sequences.\n",
 "- `02-ppo.ipynb`: Implementation of the PPOTrainer used to train language models.\n",
 "- `03-bert-imdb-training.ipynb`: Training of BERT with `simpletransformers` to classify sentiment on the IMDB dataset.\n",
 "- `03-bert-imdb-training.ipynb`: Training of DistilBERT to classify sentiment on the IMDB dataset.\n",
 "- `04-gpt2-sentiment-ppo-training.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce positive movie reviews.\n",
 "\n",
 "Currently using `trl==0.0.3`:\n",
 "- `05-gpt2-sentiment-control.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce movie reviews with controlled sentiment."
 ]
 },
@ -169,7 +173,7 @@
 "The PPO implementation largely follows the structure introduced in the paper **\"Fine-Tuning Language Models from Human Preferences\"** by D. Ziegler et al. \\[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].\n",
 "\n",
 "### Language models\n",
 "The language models utilize the `transformer` library by 🤗Hugging Face."
 "The language models utilize the `transformers` library by 🤗 Hugging Face."
 ]
 },
 {
@ -182,7 +186,7 @@
 ],
 "metadata": {
 "kernelspec": {
 "display_name": "Python 3",
 "display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 }

@ -4,7 +4,7 @@ jupyterlab==2.2.10
nbdev==0.2.16
datasets==1.17.0
torch>=1.4.0
tqdm==4.43.0
tqdm
transformers==4.15.0
wandb==0.10.20
matplotlib==3.5.1

@ -2,7 +2,8 @@

__all__ = ["index", "modules", "custom_doc_links", "git_url"]

index = {"flatten_dict": "00-core.ipynb",
index = {"WANDB_PADDING": "00-core.ipynb",
         "flatten_dict": "00-core.ipynb",
         "stack_dicts": "00-core.ipynb",
         "add_suffix": "00-core.ipynb",
         "pad_to_size": "00-core.ipynb",
@ -12,7 +13,9 @@ index = {"flatten_dict": "00-core.ipynb",
         "entropy_from_logits": "00-core.ipynb",
         "average_torch_dicts": "00-core.ipynb",
         "stats_to_np": "00-core.ipynb",
         "listify_batch": "00-core.ipynb",
         "build_bert_batch_from_txt": "00-core.ipynb",
         "CausalLMOutputWithCrossAttentions": "01-gpt2-with-value-head.ipynb",
         "ValueHead": "01-gpt2-with-value-head.ipynb",
         "GPT2HeadWithValueModel": "01-gpt2-with-value-head.ipynb",
         "respond_to_batch": "01-gpt2-with-value-head.ipynb",

15
trl/core.py
@ -1,14 +1,20 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00-core.ipynb (unless otherwise specified).

__all__ = ['flatten_dict', 'stack_dicts', 'add_suffix', 'pad_to_size', 'logprobs_from_logits', 'whiten',
           'clip_by_value', 'entropy_from_logits', 'average_torch_dicts', 'stats_to_np', 'build_bert_batch_from_txt']
__all__ = ['WANDB_PADDING', 'flatten_dict', 'stack_dicts', 'add_suffix', 'pad_to_size', 'logprobs_from_logits',
           'whiten', 'clip_by_value', 'entropy_from_logits', 'average_torch_dicts', 'stats_to_np', 'listify_batch',
           'build_bert_batch_from_txt']

# Cell
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import collections
import numpy as np

# Cell
WANDB_PADDING = -1

# Cell

def flatten_dict(nested, sep='/'):
@ -30,7 +36,7 @@ def stack_dicts(stats_dicts):
    results = dict()
    for k in stats_dicts[0]:
        stats_list = [torch.flatten(d[k]) for d in stats_dicts]
        results[k] = torch.stack(stats_list)
        results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
    return results
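Stacking used to require every stat tensor in the batch to have the same length; with variable-length responses that no longer holds, so the stats are now right-padded with `WANDB_PADDING` before being stacked for logging. A small illustration of what `pad_sequence` does here (the numbers are made up):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

WANDB_PADDING = -1
stats_list = [torch.tensor([0.1, 0.2, 0.3]),   # e.g. per-token stats of a 3-token response
              torch.tensor([0.4, 0.5])]        # ... and of a 2-token response

padded = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
# tensor([[ 0.1000,  0.2000,  0.3000],
#         [ 0.4000,  0.5000, -1.0000]])
```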

def add_suffix(input_dict, suffix):
@ -98,6 +104,9 @@ def stats_to_np(stats_dict):
            new_dict[k] = float(new_dict[k])
    return new_dict

def listify_batch(tensor):
    """Turns the first dimension of a tensor into a list."""
    return [tensor[i] for i in range(tensor.shape[0])]
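`listify_batch` exists for backwards compatibility: code that still produces a single stacked `[batch, seq_len]` tensor can be converted to the list-of-tensors form the new `step` API expects. A minimal sketch with dummy token ids:

```python
import torch
from trl.core import listify_batch

# a padded batch of 4 responses, 16 tokens each (dummy token ids)
response_batch = torch.randint(0, 50257, (4, 16))

responses = listify_batch(response_batch)   # -> list of 4 tensors, each of shape [16]
```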

# Cell
35
trl/gpt2.py
@ -1,15 +1,29 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01-gpt2-with-value-head.ipynb (unless otherwise specified).

__all__ = ['ValueHead', 'GPT2HeadWithValueModel', 'respond_to_batch']
__all__ = ['CausalLMOutputWithCrossAttentions', 'ValueHead', 'GPT2HeadWithValueModel', 'respond_to_batch']

# Cell

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel
from transformers import top_k_top_p_filtering
from transformers.modeling_outputs import ModelOutput
from torch import nn
from torch.nn import Identity
import torch.nn.functional as F
import torch
from dataclasses import dataclass
from typing import Optional, Tuple

# Cell
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    value: Optional[torch.FloatTensor] = None

# Cell

@ -87,8 +101,11 @@ class GPT2HeadWithValueModel(GPT2PreTrainedModel):
                mc_token_ids=None,
                lm_labels=None,
                mc_labels=None,
                return_dict=False,
                output_attentions=False,
                output_hidden_states=False,
                ):

        loss=None
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
@ -104,8 +121,20 @@ class GPT2HeadWithValueModel(GPT2PreTrainedModel):
        lm_logits = self.lm_head(hidden_states)
        value = self.v_head(hidden_states).squeeze(-1)

        outputs = (lm_logits,) + transformer_outputs[1:] + (value,)

        if not return_dict:
            outputs = (lm_logits,) + transformer_outputs[1:] + (value,)
            return outputs

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
            value=value,
        )
        return outputs
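The value head model keeps its old tuple output by default, so existing `logits, _, value = model(input_ids)` call sites (including the PPO trainer) keep working, while `return_dict=True` now gives a structured output. A minimal sketch, assuming the model is loaded via the usual `transformers` `from_pretrained` mechanism:

```python
import torch
from trl.gpt2 import GPT2HeadWithValueModel

model = GPT2HeadWithValueModel.from_pretrained('gpt2')
input_ids = torch.tensor([[31373, 995]])        # dummy token ids, shape [1, 2]

logits, _, value = model(input_ids)             # tuple output, as before
out = model(input_ids, return_dict=True)        # new structured output
print(out.logits.shape, out.value.shape)        # torch.Size([1, 2, 50257]) torch.Size([1, 2])
```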

# Cell
110
trl/ppo.py
@ -11,15 +11,18 @@ import collections
import time
import random

from transformers import DataCollatorForLanguageModeling

from .core import (logprobs_from_logits,
                   whiten,
                   clip_by_value,
                   entropy_from_logits,
                   flatten_dict,
                   average_torch_dicts,
                   stats_to_np,
                   stack_dicts,
                   add_suffix)
                   whiten,
                   clip_by_value,
                   entropy_from_logits,
                   flatten_dict,
                   average_torch_dicts,
                   stats_to_np,
                   stack_dicts,
                   add_suffix,
                   WANDB_PADDING)

# Cell

@ -72,13 +75,14 @@ class PPOTrainer:
        "ppo_epochs": 4,
    }

    def __init__(self, model, ref_model, **ppo_params):
    def __init__(self, model, ref_model, tokenizer, **ppo_params):
        """
        Initialize PPOTrainer.

        Args:
            model (torch.model): Hugging Face transformer GPT2 model with value head
            ref_model (torch.model): Hugging Face transformer GPT2 refrence model used for KL penalty
            tokenizer (tokenizer): Hugging Face tokenizer
            ppo_params (dict or None): PPO parameters for training. Can include following keys:
                'lr' (float): Adam learning rate, default: 1.41e-5
                'batch_size' (int): Number of samples per optimisation step, default: 256
@ -100,6 +104,9 @@ class PPOTrainer:

        self.ref_model = ref_model
        self.model = model
        self.tokenizer = tokenizer
        self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

        self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr'])

        if self.ppo_params['adap_kl_ctrl']:
@ -110,32 +117,33 @@ class PPOTrainer:
            self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef'])


    def step(self, query, response, scores):
    def step(self, queries, responses, scores):
        """
        Run a PPO optimisation step.

        args:
            query (torch.tensor): tensor containing the encoded queries, shape [batch_size, query_length]
            response (torch.tensor): tensor containing the encoded responses, shape [batch_size, response_length]
            scores (torch.tensor): tensor containing the scores, shape [batch_size]
            queries (List): List of tensors containing the encoded queries, shape [query_length]
            responses (List): List of tensors containing the encoded responses, shape [response_length]
            scores (List): List of scalar tensors containing the scores, one per example

        returns:
            train_stats (dict): a summary of the training statistics
        """

        bs = self.ppo_params['batch_size']
        assert bs == len(queries), f"Batch size ({bs}) does not match number of examples ({len(queries)})"

        timing = dict()
        t0 = time.time()

        gen_len = response.shape[1]
        model_input = torch.cat((query, response), axis=1)
        response_lengths = [len(r) for r in responses]

        t = time.time()
        logprobs, ref_logprobs, values = self.batched_forward_pass(model_input, gen_len)
        logprobs, ref_logprobs, values = self.batched_forward_pass(queries, responses)
        timing['time/ppo/forward_pass'] = time.time()-t

        t = time.time()
        rewards, non_score_reward, kl_coef = self.compute_rewards(scores, logprobs, ref_logprobs)
        rewards, non_score_reward = self.compute_rewards(scores, logprobs, ref_logprobs)
        timing['time/ppo/compute_rewards'] = time.time()-t

        t = time.time()
@ -145,9 +153,10 @@ class PPOTrainer:
        random.shuffle(idxs)
        for i in range(bs):
            idx = idxs[i]
            train_stats = self.train_minibatch(logprobs[idx:idx+1], values[idx:idx+1],
                                               rewards[idx:idx+1], query[idx:idx+1],
                                               response[idx:idx+1], model_input[idx:idx+1])
            train_stats = self.train_minibatch(logprobs[idx].unsqueeze(0), values[idx].unsqueeze(0),
                                               rewards[idx].unsqueeze(0), queries[idx].unsqueeze(0),
                                               responses[idx].unsqueeze(0),
                                               torch.cat([queries[idx],responses[idx]]).unsqueeze(0))
            all_stats.append(train_stats)
        timing['time/ppo/optimize_step'] = time.time()-t

@ -156,11 +165,12 @@ class PPOTrainer:

        # reshape advantages/ratios such that they are not averaged.
        train_stats['policy/advantages'] = torch.flatten(train_stats['policy/advantages']).unsqueeze(0)
        train_stats['policy/advantages'] = torch.nan_to_num(train_stats['policy/advantages'], WANDB_PADDING)
        train_stats['policy/ratio'] = torch.flatten(train_stats['policy/ratio']).unsqueeze(0)

        stats = self.record_step_stats(scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs,
                                       non_score_reward=non_score_reward, train_stats=train_stats,
                                       kl_coef=kl_coef)
                                       kl_coef=self.kl_ctl.value)
        stats = stats_to_np(stats)
        timing['time/ppo/calc_stats'] = time.time()-t

|
        stats.update(timing)
        return stats

    def batched_forward_pass(self, model_input, gen_len):
    def batched_forward_pass(self, queries, responses):
        """Calculate model outputs in multiple batches."""
        bs = self.ppo_params['batch_size']
        fbs = self.ppo_params['forward_batch_size']
        logprobs = []
        ref_logprobs = []
        values = []
        all_logprobs = []
        all_ref_logprobs = []
        all_values = []

        for i in range(int(self.ppo_params['batch_size']/fbs)):
            m_input = model_input[i*fbs:(i+1)*fbs]
            logits, _, v = self.model(m_input)
            ref_logits, _, _ = self.ref_model(m_input)

            values.append(v[:, -gen_len-1:-1].detach())
            logprobs.append(logprobs_from_logits(logits[:,:-1,:], m_input[:,1:])[:, -gen_len:].detach())
            ref_logprobs.append(logprobs_from_logits(ref_logits[:,:-1,:], m_input[:,1:])[:, -gen_len:].detach())

        return torch.cat(logprobs), torch.cat(ref_logprobs), torch.cat(values)
        for i in range(int(bs/fbs)):
            query_batch = queries[i*fbs:(i+1)*fbs]
            response_batch = responses[i*fbs:(i+1)*fbs]
            input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])["input_ids"]
            with torch.no_grad():
                logits, _, v = self.model(input_ids)
                ref_logits, _, _ = self.ref_model(input_ids)
            logprobs = logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:])
            ref_logprobs = logprobs_from_logits(ref_logits[:,:-1,:], input_ids[:,1:])
            for j in range(fbs):
                start = len(query_batch[j])-1
                end = len(query_batch[j]) + len(response_batch[j])-1
                all_values.append(v[j, start-1:end-1])
                all_logprobs.append(logprobs[j, start:end])
                all_ref_logprobs.append(ref_logprobs[j, start:end])
        return all_logprobs, all_ref_logprobs, all_values
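The index arithmetic above is what makes variable lengths work: for each example only the slice covering the generated response is kept, regardless of how long the query was. A worked toy case (numbers purely illustrative):

```python
# hypothetical example: a query of 5 tokens followed by a response of 3 tokens
query_len, response_len = 5, 3

start = query_len - 1                 # 4: the shifted log-prob at index 4 belongs to the first response token
end = query_len + response_len - 1    # 7: one past the last response token

# logprobs[j, start:end] -> log-probs of exactly the 3 response tokens
# v[j, start-1:end-1]    -> value estimates at the positions just before each response token
```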

    def train_minibatch(self, logprobs, values, rewards, query, response, model_input):
        """Train one PPO minibatch"""
@ -200,11 +216,15 @@ class PPOTrainer:

    def compute_rewards(self, scores, logprobs, ref_logprobs):
        """Compute per token rewards from scores and KL-penalty."""
        kl = logprobs - ref_logprobs
        non_score_reward = -self.kl_ctl.value * kl
        rewards = non_score_reward.clone().detach()
        rewards[:, -1] += scores
        return rewards, non_score_reward, self.kl_ctl.value
        rewards, non_score_rewards = [], []
        for score, logprob, ref_logprob in zip(scores, logprobs, ref_logprobs):
            kl = logprob - ref_logprob
            non_score_reward = -self.kl_ctl.value * kl
            non_score_rewards.append(non_score_reward)
            reward = non_score_reward.clone()
            reward[-1] += score
            rewards.append(reward)
        return rewards, non_score_rewards
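Per example, the reward is still the per-token KL penalty against the reference model, with the scalar score added on the final response token; the only change is that it is now computed sequence by sequence instead of on one padded batch tensor. A toy illustration with made-up numbers:

```python
import torch

kl_coef = 0.2
logprob     = torch.tensor([-1.0, -2.0, -1.5])   # log-probs of one 3-token response (made up)
ref_logprob = torch.tensor([-1.1, -1.8, -1.5])   # reference model log-probs (made up)
score       = torch.tensor(1.0)                  # e.g. sentiment classifier output for this response

non_score_reward = -kl_coef * (logprob - ref_logprob)   # per-token KL penalty
reward = non_score_reward.clone()
reward[-1] += score                                      # the score only enters at the last token
# reward -> tensor([-0.0200,  0.0400,  1.0000])
```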

    def loss(self, old_logprobs, values, rewards, query, response, model_input):
        """Calculate policy and value losses."""
@ -269,13 +289,13 @@ class PPOTrainer:

    def record_step_stats(self, kl_coef, **data):
        """Record training step statistics."""
        kl = data['logprobs'] - data['ref_logprobs']
        mean_kl = torch.mean(torch.sum(kl, axis=-1))
        mean_entropy = torch.mean(torch.sum(-data['logprobs'], axis=1))
        mean_non_score_reward =torch.mean(torch.sum(data['non_score_reward'], axis=1))
        kl_list = [logprobs-ref_logprobs for logprobs, ref_logprobs in zip(data['logprobs'], data['ref_logprobs'])]
        mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list]))
        mean_entropy = torch.mean(torch.stack([torch.sum(-log_probs) for log_probs in data['logprobs']]))
        mean_non_score_reward = torch.mean(torch.stack([torch.sum(non_score_reward) for non_score_reward in data['non_score_reward']]))
        stats = {
            'objective/kl': mean_kl,
            'objective/kl_dist': kl,
            'objective/kl_dist': kl_list,
            'objective/logprobs': data['logprobs'],
            'objective/ref_logprobs': data['ref_logprobs'],
            'objective/kl_coef': kl_coef,