Dynamic input sizes (#35)

* change ppo input from tensor to list of tensors for varying shapes (see the sketch below the commit summary)

* update readme example with new input type

* update docs

* add listification of tensors needed for the new API

* replace nans in tensors for wandb compatibility

* add `listify_batch` helper function for backwards compatibility

* update sentiment example with new api

* update docs

* update library

* ignore wandb artifacts

* update requirements

* run experiment

* replace `respond_to_batch` with `generate`

* add experiment

* update docs

* fix action

* fix action
Leandro von Werra
2022-05-15 18:16:25 +02:00
committed by GitHub
parent 5410be61b4
commit 52910d3bf1
24 changed files with 1356 additions and 977 deletions
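
The first bullet is the core API change: `PPOTrainer.step` now takes plain Python lists with one tensor per example instead of a single padded batch tensor, so queries and responses of different lengths can be mixed in one batch. A minimal sketch of the new calling convention (token ids and reward values are hypothetical):

```python
import torch

# Old API: one padded tensor per argument, shape [batch_size, seq_len],
# so every query/response in the batch had to have the same length.
# New API: plain Python lists with one 1-D tensor per example, so lengths may differ.
queries = [torch.tensor([31373, 995]),               # hypothetical token ids
           torch.tensor([31373, 11, 995, 50256])]    # a longer query is now fine
responses = [torch.tensor([464, 3807]),
             torch.tensor([464])]
rewards = [torch.tensor(1.0), torch.tensor(-0.5)]    # one scalar tensor per example

# With a PPOTrainer configured for batch_size=2 (see the README example below):
# train_stats = ppo_trainer.step(queries, responses, rewards)
```

The README diff below shows the same change applied to the single-sample example.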


@ -30,4 +30,4 @@ jobs:
if [ -n "$(nbdev_diff_nbs)" ]; then echo -e "!!! Detected difference between the notebooks and the library"; false; fi
- name: Run tests
run: |
nbdev_test_nbs
nbdev_test_nbs --fname 'nbs/[!03|!04|!05|]*.ipynb'

.gitignore

@ -139,3 +139,5 @@ checklink/cookies.txt
# .gitconfig is now autogenerated
.gitconfig
nbs/wandb/


@ -3,11 +3,11 @@
## What is it?
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the `transformer` library by 🤗 Hugging Face ([link](https://github.com/huggingface/transformers)). Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GPT2 is implemented.
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point only decoder architectures such as GPT2 are implemented.
**Highlights:**
- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.
## How it works
@ -29,27 +29,29 @@ This process is illustrated in the sketch below:
### Python package
Install the library with pip:
```bash
pip install trl
```
`pip install trl`
### Repository
### From source
If you want to run the examples in the repository, a few additional libraries are required. Clone the repository and install it with pip:
`git clone https://github.com/lvwerra/trl.git`
`cd trl/`
`pip install -r requirements.txt`
```bash
git clone https://github.com/lvwerra/trl.git
cd trl/
pip install -r requirements.txt
```
### Jupyter notebooks
If you run Jupyter notebooks, you might need to run the following:
`jupyter nbextension enable --py --sys-prefix widgetsnbextension`
```bash
jupyter nbextension enable --py --sys-prefix widgetsnbextension
```
For JupyterLab, additionally run this command:
`jupyter labextension install @jupyter-widgets/jupyterlab-manager`
```bash
jupyter labextension install @jupyter-widgets/jupyterlab-manager
```
## How to use
@ -70,7 +72,7 @@ gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# initialize trainer
ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, **ppo_config)
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)
# encode a query
query_txt = "This morning I went to the "
@ -82,14 +84,14 @@ response_txt = gpt2_tokenizer.decode(response_tensor[0,:])
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = torch.tensor([1.0])
reward = [torch.tensor(1.0)]
# train model with ppo
train_stats = ppo_trainer.step(query_tensor, response_tensor, reward)
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
### Advanced example: IMDB sentiment
For a detailed example, check out the notebook *Tune GPT2 to generate positive reviews*, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:
For a detailed example, check out the notebook `04-gpt2-sentiment-ppo-training.ipynb`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:
<div style="text-align: center">
<img src="nbs/images/table_imdb_preview.png" width="800">
@ -104,8 +106,10 @@ This library is built with `nbdev` and as such all the library code as well as e
- `00-core.ipynb`: Contains the utility functions used throughout the library and examples.
- `01-gpt2-with-value-head.ipynb`: Implementation of a `transformer` compatible GPT2 model with an additional value head as well as a function to generate sequences.
- `02-ppo.ipynb`: Implementation of the PPOTrainer used to train language models.
- `03-bert-imdb-training.ipynb`: Training of BERT with `simpletransformers` to classify sentiment on the IMDB dataset.
- `03-bert-imdb-training.ipynb`: Training of DistilBERT to classify sentiment on the IMDB dataset.
- `04-gpt2-sentiment-ppo-training.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce positive movie reviews.
Currently using `trl==0.0.3`:
- `05-gpt2-sentiment-control.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce movie reviews with controlled sentiment.
## References
@ -114,4 +118,4 @@ This library is built with `nbdev` and as such all the library code as well as e
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
### Language models
The language models utilize the `transformer` library by 🤗Hugging Face.
The language models utilize the `transformers` library by 🤗 Hugging Face.


@ -31,6 +31,19 @@ description: "A set of utility functions used throughout the library."
<div class="cell border-box-sizing code_cell rendered">
</div>
{% endraw %}
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h2 id="Constants">Constants<a class="anchor-link" href="#Constants"> </a></h2>
</div>
</div>
</div>
{% raw %}
<div class="cell border-box-sizing code_cell rendered">
</div>
{% endraw %}
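
A new "Constants" cell is added to the core docs; its body is elided in this hunk, but the `WANDB_PADDING` sentinel it introduces is used below in `stack_dicts` and later in the trainer to keep ragged statistics loggable to wandb. A minimal sketch of the idea, assuming the sentinel is simply a fixed filler value such as -1 (the actual value lives in `trl/core.py`):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

WANDB_PADDING = -1  # assumed sentinel value; the real constant is defined in trl/core.py

# With varying input sizes, per-minibatch statistics can have different lengths.
stats_list = [torch.tensor([0.1, 0.2, 0.3]), torch.tensor([0.4, 0.5])]

# Pad to a rectangular tensor so wandb can log it; NaNs are replaced the same way.
padded = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
padded = torch.nan_to_num(padded, WANDB_PADDING)
print(padded)  # the shorter row is filled up with the padding value
```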
@ -66,7 +79,7 @@ description: "A set of utility functions used throughout the library."
<span class="n">results</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">stats_dicts</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
<span class="n">stats_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">d</span><span class="p">[</span><span class="n">k</span><span class="p">])</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">stats_dicts</span><span class="p">]</span>
<span class="n">results</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">stats_list</span><span class="p">)</span>
<span class="n">results</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">pad_sequence</span><span class="p">(</span><span class="n">stats_list</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="n">WANDB_PADDING</span><span class="p">)</span>
<span class="k">return</span> <span class="n">results</span>
<span class="k">def</span> <span class="nf">add_suffix</span><span class="p">(</span><span class="n">input_dict</span><span class="p">,</span> <span class="n">suffix</span><span class="p">):</span>
@ -92,7 +105,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L14" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
<h4 id="flatten_dict" class="doc_header"><code>flatten_dict</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L20" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>flatten_dict</code>(<strong><code>nested</code></strong>, <strong><code>sep</code></strong>=<em><code>'/'</code></em>)</p>
</blockquote>
<p>Flatten dictionary and concatenate nested keys with separator.</p>
@ -117,7 +130,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L28" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
<h4 id="stack_dicts" class="doc_header"><code>stack_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L34" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stack_dicts</code>(<strong><code>stats_dicts</code></strong>)</p>
</blockquote>
<p>Stack the values of a dict.</p>
@ -142,7 +155,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L36" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
<h4 id="add_suffix" class="doc_header"><code>add_suffix</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>add_suffix</code>(<strong><code>input_dict</code></strong>, <strong><code>suffix</code></strong>)</p>
</blockquote>
<p>Add suffix to dict keys.</p>
@ -227,6 +240,10 @@ description: "A set of utility functions used throughout the library."
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]):</span>
<span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">new_dict</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
<span class="k">return</span> <span class="n">new_dict</span>
<span class="k">def</span> <span class="nf">listify_batch</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Turns the first dimension of a tensor into a list.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="p">[</span><span class="n">tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
</pre></div>
</div>
@ -247,7 +264,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L42" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
<h4 id="pad_to_size" class="doc_header"><code>pad_to_size</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L48" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>pad_to_size</code>(<strong><code>tensor</code></strong>, <strong><code>size</code></strong>, <strong><code>dim</code></strong>=<em><code>1</code></em>, <strong><code>padding</code></strong>=<em><code>50256</code></em>)</p>
</blockquote>
<p>Pad tensor to size.</p>
@ -272,7 +289,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L50" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
<h4 id="logprobs_from_logits" class="doc_header"><code>logprobs_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L56" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>logprobs_from_logits</code>(<strong><code>logits</code></strong>, <strong><code>labels</code></strong>)</p>
</blockquote>
<p>See: <a href="https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591">https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591</a></p>
@ -297,7 +314,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L59" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
<h4 id="whiten" class="doc_header"><code>whiten</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L65" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>whiten</code>(<strong><code>values</code></strong>, <strong><code>shift_mean</code></strong>=<em><code>True</code></em>)</p>
</blockquote>
<p>Whiten values.</p>
@ -322,7 +339,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L67" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
<h4 id="clip_by_value" class="doc_header"><code>clip_by_value</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L73" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>clip_by_value</code>(<strong><code>x</code></strong>, <strong><code>tensor_min</code></strong>, <strong><code>tensor_max</code></strong>)</p>
</blockquote>
<p>Tensor extension to torch.clamp
<a href="https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713">https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713</a></p>
@ -348,7 +365,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L75" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
<h4 id="entropy_from_logits" class="doc_header"><code>entropy_from_logits</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L81" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>entropy_from_logits</code>(<strong><code>logits</code></strong>)</p>
</blockquote>
<p>Calculate entropy from logits.</p>
@ -373,7 +390,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L82" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
<h4 id="average_torch_dicts" class="doc_header"><code>average_torch_dicts</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L88" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>average_torch_dicts</code>(<strong><code>list_of_dicts</code></strong>)</p>
</blockquote>
<p>Average values of a list of dicts with torch tensors.</p>
@ -398,7 +415,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L89" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
<h4 id="stats_to_np" class="doc_header"><code>stats_to_np</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L95" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>stats_to_np</code>(<strong><code>stats_dict</code></strong>)</p>
</blockquote>
<p>Cast all torch.tensors in dict to numpy arrays.</p>
@ -409,6 +426,31 @@ description: "A set of utility functions used throughout the library."
</div>
</div>
</div>
{% endraw %}
{% raw %}
<div class="cell border-box-sizing code_cell rendered">
<div class="output_wrapper">
<div class="output">
<div class="output_area">
<div class="output_markdown rendered_html output_subarea ">
<h4 id="listify_batch" class="doc_header"><code>listify_batch</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L107" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>listify_batch</code>(<strong><code>tensor</code></strong>)</p>
</blockquote>
<p>Turns the first dimension of a tensor into a list.</p>
</div>
</div>
</div>
</div>
</div>
{% endraw %}
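
`listify_batch` is the backwards-compatibility helper mentioned in the commit message: it converts an old-style padded `[batch_size, seq_len]` tensor into the list of per-example tensors that the new `PPOTrainer.step` expects. A minimal usage sketch with hypothetical token ids (the import path is assumed from the `trl/core.py` source links above):

```python
import torch
from trl.core import listify_batch  # assumed import path

# Old-style padded batch of token ids, shape [2, 3] (values are hypothetical).
batch = torch.tensor([[31373, 995, 50256],
                      [31373, 11, 995]])

tensor_list = listify_batch(batch)
print(len(tensor_list))   # 2
print(tensor_list[0])     # tensor([31373,   995, 50256])
```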
@ -468,7 +510,7 @@ description: "A set of utility functions used throughout the library."
<div class="output_markdown rendered_html output_subarea ">
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L104" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
<h4 id="build_bert_batch_from_txt" class="doc_header"><code>build_bert_batch_from_txt</code><a href="https://github.com/lvwerra/trl/tree/master/trl/core.py#L113" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>build_bert_batch_from_txt</code>(<strong><code>text_list</code></strong>, <strong><code>tokenizer</code></strong>, <strong><code>device</code></strong>)</p>
</blockquote>
<p>Create token id and attention mask tensors from text list for BERT classification.</p>
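
A hedged usage sketch for `build_bert_batch_from_txt`; the tokenizer checkpoint is illustrative, the import path is assumed from the source links, and the return order (padded token ids, then attention masks) is assumed from the description above:

```python
import torch
from transformers import AutoTokenizer
from trl.core import build_bert_batch_from_txt  # assumed import path

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint
texts = ["this movie was great", "terrible acting and a dull plot"]

# Assumed return order: padded token ids and attention masks, both on `device`.
input_ids, attention_masks = build_bert_batch_from_txt(texts, tokenizer, device)
```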


@ -53,6 +53,56 @@ description: "A GPT2 model with a value head built on the `transformer` library
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="nd">@dataclass</span>
<span class="k">class</span> <span class="nc">CausalLMOutputWithCrossAttentions</span><span class="p">(</span><span class="n">ModelOutput</span><span class="p">):</span>
<span class="n">loss</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">past_key_values</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">attentions</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">cross_attentions</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
</pre></div>
</div>
</div>
</div>
</div>
{% endraw %}
{% raw %}
<div class="cell border-box-sizing code_cell rendered">
<div class="output_wrapper">
<div class="output">
<div class="output_area">
<div class="output_markdown rendered_html output_subarea ">
<h2 id="CausalLMOutputWithCrossAttentions" class="doc_header"><code>class</code> <code>CausalLMOutputWithCrossAttentions</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L18" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>CausalLMOutputWithCrossAttentions</code>(<strong><code>loss</code></strong>:<code>Optional</code>[<code>FloatTensor</code>]=<em><code>None</code></em>, <strong><code>logits</code></strong>:<code>FloatTensor</code>=<em><code>None</code></em>, <strong><code>past_key_values</code></strong>:<code>Optional</code>[<code>Tuple</code>[<code>Tuple</code>[<code>FloatTensor</code>]]]=<em><code>None</code></em>, <strong><code>hidden_states</code></strong>:<code>Optional</code>[<code>Tuple</code>[<code>FloatTensor</code>]]=<em><code>None</code></em>, <strong><code>attentions</code></strong>:<code>Optional</code>[<code>Tuple</code>[<code>FloatTensor</code>]]=<em><code>None</code></em>, <strong><code>cross_attentions</code></strong>:<code>Optional</code>[<code>Tuple</code>[<code>FloatTensor</code>]]=<em><code>None</code></em>, <strong><code>value</code></strong>:<code>Optional</code>[<code>FloatTensor</code>]=<em><code>None</code></em>) :: <code>ModelOutput</code></p>
</blockquote>
<p>CausalLMOutputWithCrossAttentions(loss: Optional[torch.FloatTensor] = None, logits: torch.FloatTensor = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, hidden_states: Optional[Tuple[torch.FloatTensor]] = None, attentions: Optional[Tuple[torch.FloatTensor]] = None, cross_attentions: Optional[Tuple[torch.FloatTensor]] = None, value: Optional[torch.FloatTensor] = None)</p>
</div>
</div>
</div>
</div>
</div>
{% endraw %}
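
The new `CausalLMOutputWithCrossAttentions` dataclass lets the value-head model return a structured output when `return_dict=True` is passed to `forward` (see the forward changes further down). A minimal sketch, assuming the `trl.gpt2` import path implied by the source links:

```python
import torch
from transformers import GPT2Tokenizer
from trl.gpt2 import GPT2HeadWithValueModel  # assumed import path

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2HeadWithValueModel.from_pretrained("gpt2")

input_ids = tokenizer("This morning I went to the", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids, return_dict=True)

print(out.logits.shape)  # [1, seq_len, vocab_size]
print(out.value.shape)   # [1, seq_len] -- one scalar value per token
```

Tuple outputs remain the default (`return_dict=False`), so existing code keeps working.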
{% raw %}
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="k">class</span> <span class="nc">ValueHead</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
@ -117,7 +167,7 @@ description: "A GPT2 model with a value head built on the `transformer` library
<div class="output_markdown rendered_html output_subarea ">
<h2 id="ValueHead" class="doc_header"><code>class</code> <code>ValueHead</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L16" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ValueHead</code>(<strong><code>config</code></strong>) :: <code>Module</code></p>
<h2 id="ValueHead" class="doc_header"><code>class</code> <code>ValueHead</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L30" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ValueHead</code>(<strong><code>config</code></strong>) :: <code>Module</code></p>
</blockquote>
<p>The ValueHead class implements a head for GPT2 that returns a scalar for each output token.</p>
@ -167,8 +217,11 @@ description: "A GPT2 model with a value head built on the `transformer` library
<span class="n">mc_token_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">lm_labels</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">mc_labels</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">return_dict</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">output_attentions</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">output_hidden_states</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">loss</span><span class="o">=</span><span class="kc">None</span>
<span class="n">transformer_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="p">(</span>
<span class="n">input_ids</span><span class="p">,</span>
<span class="n">past_key_values</span><span class="o">=</span><span class="n">past_key_values</span><span class="p">,</span>
@ -184,8 +237,20 @@ description: "A GPT2 model with a value head built on the `transformer` library
<span class="n">lm_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">v_head</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">(</span><span class="n">lm_logits</span><span class="p">,)</span> <span class="o">+</span> <span class="n">transformer_outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span><span class="p">,)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">(</span><span class="n">lm_logits</span><span class="p">,)</span> <span class="o">+</span> <span class="n">transformer_outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span><span class="p">,)</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">return</span> <span class="n">CausalLMOutputWithCrossAttentions</span><span class="p">(</span>
<span class="n">loss</span><span class="o">=</span><span class="n">loss</span><span class="p">,</span>
<span class="n">logits</span><span class="o">=</span><span class="n">lm_logits</span><span class="p">,</span>
<span class="n">past_key_values</span><span class="o">=</span><span class="n">transformer_outputs</span><span class="o">.</span><span class="n">past_key_values</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="o">=</span><span class="n">transformer_outputs</span><span class="o">.</span><span class="n">hidden_states</span><span class="p">,</span>
<span class="n">attentions</span><span class="o">=</span><span class="n">transformer_outputs</span><span class="o">.</span><span class="n">attentions</span><span class="p">,</span>
<span class="n">cross_attentions</span><span class="o">=</span><span class="n">transformer_outputs</span><span class="o">.</span><span class="n">cross_attentions</span><span class="p">,</span>
<span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">outputs</span>
</pre></div>
@ -207,7 +272,7 @@ description: "A GPT2 model with a value head built on the `transformer` library
<div class="output_markdown rendered_html output_subarea ">
<h2 id="GPT2HeadWithValueModel" class="doc_header"><code>class</code> <code>GPT2HeadWithValueModel</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L61" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>GPT2HeadWithValueModel</code>(<strong><code>config</code></strong>) :: <code>GPT2PreTrainedModel</code></p>
<h2 id="GPT2HeadWithValueModel" class="doc_header"><code>class</code> <code>GPT2HeadWithValueModel</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L75" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>GPT2HeadWithValueModel</code>(<strong><code>config</code></strong>) :: <code>GPT2PreTrainedModel</code></p>
</blockquote>
<p>The GPT2HeadWithValueModel class implements a GPT2 language model with a secondary, scalar head.</p>
@ -243,6 +308,21 @@ description: "A GPT2 model with a value head built on the `transformer` library
</div>
</div>
<div class="output_wrapper">
<div class="output">
<div class="output_area">
<div class="output_subarea output_stream output_stderr output_text">
<pre>Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: [&#39;h.6.attn.masked_bias&#39;, &#39;h.10.attn.masked_bias&#39;, &#39;h.0.attn.masked_bias&#39;, &#39;h.3.attn.masked_bias&#39;, &#39;h.7.attn.masked_bias&#39;, &#39;h.5.attn.masked_bias&#39;, &#39;h.11.attn.masked_bias&#39;, &#39;h.9.attn.masked_bias&#39;, &#39;h.8.attn.masked_bias&#39;, &#39;lm_head.weight&#39;, &#39;h.4.attn.masked_bias&#39;, &#39;v_head.summary.weight&#39;, &#39;h.2.attn.masked_bias&#39;, &#39;v_head.summary.bias&#39;, &#39;h.1.attn.masked_bias&#39;]
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
</pre>
</div>
</div>
</div>
</div>
</div>
{% endraw %}
@ -508,7 +588,7 @@ description: "A GPT2 model with a value head built on the `transformer` library
<div class="output_markdown rendered_html output_subarea ">
<h4 id="respond_to_batch" class="doc_header"><code>respond_to_batch</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L113" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>respond_to_batch</code>(<strong><code>model</code></strong>, <strong><code>queries</code></strong>, <strong><code>txt_len</code></strong>=<em><code>20</code></em>, <strong><code>top_k</code></strong>=<em><code>0</code></em>, <strong><code>top_p</code></strong>=<em><code>1.0</code></em>)</p>
<h4 id="respond_to_batch" class="doc_header"><code>respond_to_batch</code><a href="https://github.com/lvwerra/trl/tree/master/trl/gpt2.py#L142" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>respond_to_batch</code>(<strong><code>model</code></strong>, <strong><code>queries</code></strong>, <strong><code>txt_len</code></strong>=<em><code>20</code></em>, <strong><code>top_k</code></strong>=<em><code>0</code></em>, <strong><code>top_p</code></strong>=<em><code>1.0</code></em>)</p>
</blockquote>
<p>Sample text from language model.</p>
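
Per the commit bullet "replace `respond_to_batch` with `generate`", responses can also be sampled with `transformers`' standard `generate` method; the structured `return_dict` output above is what enables this. A hedged sketch with sampling settings chosen to mirror `respond_to_batch`'s defaults (`txt_len=20`, `top_k=0`, `top_p=1.0`):

```python
import torch
from transformers import GPT2Tokenizer
from trl.gpt2 import GPT2HeadWithValueModel  # assumed import path, as above

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2HeadWithValueModel.from_pretrained("gpt2")

query = tokenizer("This morning I went to the", return_tensors="pt").input_ids
response = model.generate(
    query,
    max_new_tokens=20,            # mirrors respond_to_batch's txt_len=20
    do_sample=True,
    top_k=0,
    top_p=1.0,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(response[0, query.shape[1]:]))  # decode only the generated continuation
```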


@ -91,7 +91,7 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<div class="output_markdown rendered_html output_subarea ">
<h2 id="AdaptiveKLController" class="doc_header"><code>class</code> <code>AdaptiveKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L26" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>AdaptiveKLController</code>(<strong><code>init_kl_coef</code></strong>, <strong><code>target</code></strong>, <strong><code>horizon</code></strong>)</p>
<h2 id="AdaptiveKLController" class="doc_header"><code>class</code> <code>AdaptiveKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L29" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>AdaptiveKLController</code>(<strong><code>init_kl_coef</code></strong>, <strong><code>target</code></strong>, <strong><code>horizon</code></strong>)</p>
</blockquote>
<p>Adaptive KL controller described in the paper:
<a href="https://arxiv.org/pdf/1909.08593.pdf">https://arxiv.org/pdf/1909.08593.pdf</a></p>
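
The adaptive controller scales the KL coefficient so the measured KL stays near a target. A minimal sketch of the proportional update described in Ziegler et al.; the clipping constant and exact update form are reconstructed from the paper and should be treated as an assumption:

```python
class AdaptiveKLControllerSketch:
    """Proportional controller: grow kl_coef when KL overshoots the target, shrink it otherwise."""
    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Clipped relative error between measured KL and target (assumed clip range of 0.2).
        proportional_error = max(min(current_kl / self.target - 1, 0.2), -0.2)
        self.value *= 1 + proportional_error * n_steps / self.horizon


ctl = AdaptiveKLControllerSketch(init_kl_coef=0.2, target=6.0, horizon=10000)
ctl.update(current_kl=9.0, n_steps=256)  # KL above target -> coefficient increases slightly
print(ctl.value)
```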
@ -140,7 +140,7 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<div class="output_markdown rendered_html output_subarea ">
<h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L44" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>FixedKLController</code>(<strong><code>kl_coef</code></strong>)</p>
<h2 id="FixedKLController" class="doc_header"><code>class</code> <code>FixedKLController</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L47" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>FixedKLController</code>(<strong><code>kl_coef</code></strong>)</p>
</blockquote>
<p>Fixed KL controller.</p>
@ -182,13 +182,14 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<span class="s2">&quot;ppo_epochs&quot;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">ref_model</span><span class="p">,</span> <span class="o">**</span><span class="n">ppo_params</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">ref_model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="o">**</span><span class="n">ppo_params</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Initialize PPOTrainer.</span>
<span class="sd"> </span>
<span class="sd"> Args:</span>
<span class="sd"> model (torch.model): Hugging Face transformer GPT2 model with value head</span>
<span class="sd"> ref_model (torch.model): Hugging Face transformer GPT2 reference model used for KL penalty</span>
<span class="sd"> tokenizer (tokenizer): Hugging Face tokenizer</span>
<span class="sd"> ppo_params (dict or None): PPO parameters for training. Can include following keys:</span>
<span class="sd"> &#39;lr&#39; (float): Adam learning rate, default: 1.41e-5</span>
<span class="sd"> &#39;batch_size&#39; (int): Number of samples per optimisation step, default: 256</span>
@ -210,6 +211,9 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<span class="bp">self</span><span class="o">.</span><span class="n">ref_model</span> <span class="o">=</span> <span class="n">ref_model</span>
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">tokenizer</span>
<span class="bp">self</span><span class="o">.</span><span class="n">data_collator</span> <span class="o">=</span> <span class="n">DataCollatorForLanguageModeling</span><span class="p">(</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">mlm</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;adap_kl_ctrl&#39;</span><span class="p">]:</span>
@ -220,32 +224,33 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span> <span class="o">=</span> <span class="n">FixedKLController</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;init_kl_coef&#39;</span><span class="p">])</span>
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">,</span> <span class="n">scores</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">queries</span><span class="p">,</span> <span class="n">responses</span><span class="p">,</span> <span class="n">scores</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Run a PPO optimisation step.</span>
<span class="sd"> </span>
<span class="sd"> args:</span>
<span class="sd"> query (torch.tensor): tensor containing the encoded queries, shape [batch_size, query_length]</span>
<span class="sd"> response (torch.tensor): tensor containing the encoded responses, shape [batch_size, response_length]</span>
<span class="sd"> scores (torch.tensor): tensor containing the scores, shape [batch_size]</span>
<span class="sd"> queries (List): List of tensors containing the encoded queries, shape [query_length]</span>
<span class="sd"> responses (List): List of tensors containing the encoded responses, shape [response_length]</span>
<span class="sd"> scores (List): List of tensors containing the scores, shape [batch_size]</span>
<span class="sd"> </span>
<span class="sd"> returns:</span>
<span class="sd"> train_stats (dict): a summary of the training statistics</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">bs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;batch_size&#39;</span><span class="p">]</span>
<span class="k">assert</span> <span class="n">bs</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">queries</span><span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;Batch size (</span><span class="si">{</span><span class="n">bs</span><span class="si">}</span><span class="s2">) does not match number of examples (</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">queries</span><span class="p">)</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="n">timing</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="n">t0</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">gen_len</span> <span class="o">=</span> <span class="n">response</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">model_input</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">response_lengths</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">r</span><span class="p">)</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">responses</span><span class="p">]</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">,</span> <span class="n">values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batched_forward_pass</span><span class="p">(</span><span class="n">model_input</span><span class="p">,</span> <span class="n">gen_len</span><span class="p">)</span>
<span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">,</span> <span class="n">values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batched_forward_pass</span><span class="p">(</span><span class="n">queries</span><span class="p">,</span> <span class="n">responses</span><span class="p">)</span>
<span class="n">timing</span><span class="p">[</span><span class="s1">&#39;time/ppo/forward_pass&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">t</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_reward</span><span class="p">,</span> <span class="n">kl_coef</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_rewards</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">)</span>
<span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_reward</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_rewards</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">)</span>
<span class="n">timing</span><span class="p">[</span><span class="s1">&#39;time/ppo/compute_rewards&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">t</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
@ -255,9 +260,10 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">idxs</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">bs</span><span class="p">):</span>
<span class="n">idx</span> <span class="o">=</span> <span class="n">idxs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">train_stats</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_minibatch</span><span class="p">(</span><span class="n">logprobs</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">values</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
<span class="n">rewards</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">query</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
<span class="n">response</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">model_input</span><span class="p">[</span><span class="n">idx</span><span class="p">:</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">])</span>
<span class="n">train_stats</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_minibatch</span><span class="p">(</span><span class="n">logprobs</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">values</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">rewards</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">queries</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">responses</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">queries</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span><span class="n">responses</span><span class="p">[</span><span class="n">idx</span><span class="p">]])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">all_stats</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">train_stats</span><span class="p">)</span>
<span class="n">timing</span><span class="p">[</span><span class="s1">&#39;time/ppo/optimize_step&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">t</span>
@ -266,11 +272,12 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<span class="c1"># reshape advantages/ratios such that they are not averaged.</span>
<span class="n">train_stats</span><span class="p">[</span><span class="s1">&#39;policy/advantages&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">train_stats</span><span class="p">[</span><span class="s1">&#39;policy/advantages&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">train_stats</span><span class="p">[</span><span class="s1">&#39;policy/advantages&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nan_to_num</span><span class="p">(</span><span class="n">train_stats</span><span class="p">[</span><span class="s1">&#39;policy/advantages&#39;</span><span class="p">],</span> <span class="n">WANDB_PADDING</span><span class="p">)</span>
<span class="n">train_stats</span><span class="p">[</span><span class="s1">&#39;policy/ratio&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">train_stats</span><span class="p">[</span><span class="s1">&#39;policy/ratio&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">stats</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">record_step_stats</span><span class="p">(</span><span class="n">scores</span><span class="o">=</span><span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="o">=</span><span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="o">=</span><span class="n">ref_logprobs</span><span class="p">,</span>
<span class="n">non_score_reward</span><span class="o">=</span><span class="n">non_score_reward</span><span class="p">,</span> <span class="n">train_stats</span><span class="o">=</span><span class="n">train_stats</span><span class="p">,</span>
<span class="n">kl_coef</span><span class="o">=</span><span class="n">kl_coef</span><span class="p">)</span>
<span class="n">kl_coef</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span><span class="o">.</span><span class="n">value</span><span class="p">)</span>
<span class="n">stats</span> <span class="o">=</span> <span class="n">stats_to_np</span><span class="p">(</span><span class="n">stats</span><span class="p">)</span>
<span class="n">timing</span><span class="p">[</span><span class="s1">&#39;time/ppo/calc_stats&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">t</span>
@ -280,24 +287,30 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<span class="n">stats</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">timing</span><span class="p">)</span>
<span class="k">return</span> <span class="n">stats</span>
<span class="k">def</span> <span class="nf">batched_forward_pass</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_input</span><span class="p">,</span> <span class="n">gen_len</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">batched_forward_pass</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">queries</span><span class="p">,</span> <span class="n">responses</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Calculate model outputs in multiple batches.&quot;&quot;&quot;</span>
<span class="n">bs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;batch_size&#39;</span><span class="p">]</span>
<span class="n">fbs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;forward_batch_size&#39;</span><span class="p">]</span>
<span class="n">logprobs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">ref_logprobs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">values</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">all_logprobs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">all_ref_logprobs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">all_values</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;batch_size&#39;</span><span class="p">]</span><span class="o">/</span><span class="n">fbs</span><span class="p">)):</span>
<span class="n">m_input</span> <span class="o">=</span> <span class="n">model_input</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">fbs</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">fbs</span><span class="p">]</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">m_input</span><span class="p">)</span>
<span class="n">ref_logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ref_model</span><span class="p">(</span><span class="n">m_input</span><span class="p">)</span>
<span class="n">values</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">v</span><span class="p">[:,</span> <span class="o">-</span><span class="n">gen_len</span><span class="o">-</span><span class="mi">1</span><span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
<span class="n">logprobs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">logprobs_from_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">[:,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:],</span> <span class="n">m_input</span><span class="p">[:,</span><span class="mi">1</span><span class="p">:])[:,</span> <span class="o">-</span><span class="n">gen_len</span><span class="p">:]</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
<span class="n">ref_logprobs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">logprobs_from_logits</span><span class="p">(</span><span class="n">ref_logits</span><span class="p">[:,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:],</span> <span class="n">m_input</span><span class="p">[:,</span><span class="mi">1</span><span class="p">:])[:,</span> <span class="o">-</span><span class="n">gen_len</span><span class="p">:]</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">logprobs</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">ref_logprobs</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">values</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">bs</span><span class="o">/</span><span class="n">fbs</span><span class="p">)):</span>
<span class="n">query_batch</span> <span class="o">=</span> <span class="n">queries</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">fbs</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">fbs</span><span class="p">]</span>
<span class="n">response_batch</span> <span class="o">=</span> <span class="n">responses</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">fbs</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">fbs</span><span class="p">]</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_collator</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">q</span><span class="p">,</span> <span class="n">r</span><span class="p">])</span> <span class="k">for</span> <span class="n">q</span><span class="p">,</span> <span class="n">r</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">query_batch</span><span class="p">,</span> <span class="n">response_batch</span><span class="p">)])[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">]</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
<span class="n">ref_logits</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ref_model</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
<span class="n">logprobs</span> <span class="o">=</span> <span class="n">logprobs_from_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">[:,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:],</span> <span class="n">input_ids</span><span class="p">[:,</span><span class="mi">1</span><span class="p">:])</span>
<span class="n">ref_logprobs</span> <span class="o">=</span> <span class="n">logprobs_from_logits</span><span class="p">(</span><span class="n">ref_logits</span><span class="p">[:,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:],</span> <span class="n">input_ids</span><span class="p">[:,</span><span class="mi">1</span><span class="p">:])</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">fbs</span><span class="p">):</span>
<span class="n">start</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">query_batch</span><span class="p">[</span><span class="n">j</span><span class="p">])</span><span class="o">-</span><span class="mi">1</span>
<span class="n">end</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">query_batch</span><span class="p">[</span><span class="n">j</span><span class="p">])</span> <span class="o">+</span> <span class="nb">len</span><span class="p">(</span><span class="n">response_batch</span><span class="p">[</span><span class="n">j</span><span class="p">])</span><span class="o">-</span><span class="mi">1</span>
<span class="n">all_values</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">v</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">start</span><span class="o">-</span><span class="mi">1</span><span class="p">:</span><span class="n">end</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">all_logprobs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">logprobs</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">])</span>
<span class="n">all_ref_logprobs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ref_logprobs</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">])</span>
<span class="k">return</span> <span class="n">all_logprobs</span><span class="p">,</span> <span class="n">all_ref_logprobs</span><span class="p">,</span> <span class="n">all_values</span>
<span class="k">def</span> <span class="nf">train_minibatch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">,</span> <span class="n">model_input</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Train one PPO minibatch&quot;&quot;&quot;</span>
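The delicate part of the new `batched_forward_pass` is the index arithmetic that recovers per-sample, variable-length slices from the padded forward batch. A toy sketch of that indexing with hypothetical token IDs and lengths (illustration only, not taken from the diff):

```python
import torch

# hypothetical sample: a query of 4 tokens followed by a response of 3 tokens
query = torch.tensor([11, 12, 13, 14])
response = torch.tensor([21, 22, 23])
input_ids = torch.cat([query, response])  # length 7

# logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) returns one entry per
# *predicted* token, so entry t is the log-probability assigned to input_ids[t + 1]
start = len(query) - 1                    # 3
end = len(query) + len(response) - 1      # 6
print(list(range(start, end)))            # [3, 4, 5] -> predictions of positions 4, 5, 6,
                                          # i.e. exactly the three response tokens
```

The value estimates are taken from the same window shifted back by one position, and because the slices are collected per sample, the returned log-probabilities and values keep their individual lengths instead of staying padded.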
@ -310,18 +323,22 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<span class="k">def</span> <span class="nf">compute_rewards</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Compute per token rewards from scores and KL-penalty.&quot;&quot;&quot;</span>
<span class="n">kl</span> <span class="o">=</span> <span class="n">logprobs</span> <span class="o">-</span> <span class="n">ref_logprobs</span>
<span class="n">non_score_reward</span> <span class="o">=</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span><span class="o">.</span><span class="n">value</span> <span class="o">*</span> <span class="n">kl</span>
<span class="n">rewards</span> <span class="o">=</span> <span class="n">non_score_reward</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="n">rewards</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+=</span> <span class="n">scores</span>
<span class="k">return</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_reward</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span><span class="o">.</span><span class="n">value</span>
<span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_rewards</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">score</span><span class="p">,</span> <span class="n">logprob</span><span class="p">,</span> <span class="n">ref_logprob</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span><span class="p">):</span>
<span class="n">kl</span> <span class="o">=</span> <span class="n">logprob</span> <span class="o">-</span> <span class="n">ref_logprob</span>
<span class="n">non_score_reward</span> <span class="o">=</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">kl_ctl</span><span class="o">.</span><span class="n">value</span> <span class="o">*</span> <span class="n">kl</span>
<span class="n">non_score_rewards</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">non_score_reward</span><span class="p">)</span>
<span class="n">reward</span> <span class="o">=</span> <span class="n">non_score_reward</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">reward</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+=</span> <span class="n">score</span>
<span class="n">rewards</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">reward</span><span class="p">)</span>
<span class="k">return</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">non_score_rewards</span>
<span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">old_logprobs</span><span class="p">,</span> <span class="n">values</span><span class="p">,</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">response</span><span class="p">,</span> <span class="n">model_input</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Calculate policy and value losses.&quot;&quot;&quot;</span>
<span class="n">lastgaelam</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">advantages_reversed</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">gen_len</span> <span class="o">=</span> <span class="n">response</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">gen_len</span><span class="p">)):</span>
<span class="n">nextvalues</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="k">if</span> <span class="n">t</span> <span class="o">&lt;</span> <span class="n">gen_len</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">else</span> <span class="mf">0.0</span>
<span class="n">delta</span> <span class="o">=</span> <span class="n">rewards</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_params</span><span class="p">[</span><span class="s1">&#39;gamma&#39;</span><span class="p">]</span> <span class="o">*</span> <span class="n">nextvalues</span> <span class="o">-</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
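For concreteness, a small worked example of the reward construction above with made-up numbers: every response token is penalised by the KL term against the reference model, and the scalar score is added to the last token only (a KL coefficient of 0.2 stands in for `self.kl_ctl.value`):

```python
import torch

kl_coef = 0.2                                    # stands in for self.kl_ctl.value
logprob = torch.tensor([-1.0, -2.0, -0.5])       # response log-probs under the trained model
ref_logprob = torch.tensor([-1.5, -1.0, -0.5])   # same tokens under the reference model
score = torch.tensor(1.0)                        # e.g. a sentiment classifier output

kl = logprob - ref_logprob                       # tensor([ 0.5, -1.0,  0.0])
non_score_reward = -kl_coef * kl                 # tensor([-0.1,  0.2, -0.0])
reward = non_score_reward.clone()
reward[-1] += score                              # tensor([-0.1,  0.2,  1.0])
```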
@ -379,13 +396,13 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<span class="k">def</span> <span class="nf">record_step_stats</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kl_coef</span><span class="p">,</span> <span class="o">**</span><span class="n">data</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Record training step statistics.&quot;&quot;&quot;</span>
<span class="n">kl</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;logprobs&#39;</span><span class="p">]</span> <span class="o">-</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;ref_logprobs&#39;</span><span class="p">]</span>
<span class="n">mean_kl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">kl</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">mean_entropy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;logprobs&#39;</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">mean_non_score_reward</span> <span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;non_score_reward&#39;</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">kl_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">logprobs</span><span class="o">-</span><span class="n">ref_logprobs</span> <span class="k">for</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">ref_logprobs</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;logprobs&#39;</span><span class="p">],</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;ref_logprobs&#39;</span><span class="p">])]</span>
<span class="n">mean_kl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">kl</span><span class="p">)</span> <span class="k">for</span> <span class="n">kl</span> <span class="ow">in</span> <span class="n">kl_list</span><span class="p">]))</span>
<span class="n">mean_entropy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="n">log_probs</span><span class="p">)</span> <span class="k">for</span> <span class="n">log_probs</span> <span class="ow">in</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;logprobs&#39;</span><span class="p">]]))</span>
<span class="n">mean_non_score_reward</span> <span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">non_score_reward</span><span class="p">)</span> <span class="k">for</span> <span class="n">non_score_reward</span> <span class="ow">in</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;non_score_reward&#39;</span><span class="p">]]))</span>
<span class="n">stats</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;objective/kl&#39;</span><span class="p">:</span> <span class="n">mean_kl</span><span class="p">,</span>
<span class="s1">&#39;objective/kl_dist&#39;</span><span class="p">:</span> <span class="n">kl</span><span class="p">,</span>
<span class="s1">&#39;objective/kl_dist&#39;</span><span class="p">:</span> <span class="n">kl_list</span><span class="p">,</span>
<span class="s1">&#39;objective/logprobs&#39;</span><span class="p">:</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;logprobs&#39;</span><span class="p">],</span>
<span class="s1">&#39;objective/ref_logprobs&#39;</span><span class="p">:</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;ref_logprobs&#39;</span><span class="p">],</span>
<span class="s1">&#39;objective/kl_coef&#39;</span><span class="p">:</span> <span class="n">kl_coef</span><span class="p">,</span>
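Because the per-sample log-probabilities now have different lengths, the statistics can no longer be computed on a single stacked tensor; each sequence is first reduced to a scalar sum and only then stacked. A minimal sketch with hypothetical values:

```python
import torch

# two responses of different lengths -> a ragged list of per-token KL tensors
kl_list = [torch.tensor([0.5, -1.0]), torch.tensor([0.1, 0.2, 0.3])]
mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list]))
print(mean_kl)  # (-0.5 + 0.6) / 2 = 0.05
```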
@ -417,7 +434,7 @@ description: "A Pytorch implementation of Proximal Policy Optimization for trans
<div class="output_markdown rendered_html output_subarea ">
<h2 id="PPOTrainer" class="doc_header"><code>class</code> <code>PPOTrainer</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L54" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>PPOTrainer</code>(<strong><code>model</code></strong>, <strong><code>ref_model</code></strong>, <strong>**<code>ppo_params</code></strong>)</p>
<h2 id="PPOTrainer" class="doc_header"><code>class</code> <code>PPOTrainer</code><a href="https://github.com/lvwerra/trl/tree/master/trl/ppo.py#L57" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>PPOTrainer</code>(<strong><code>model</code></strong>, <strong><code>ref_model</code></strong>, <strong><code>tokenizer</code></strong>, <strong>**<code>ppo_params</code></strong>)</p>
</blockquote>
<p>The PPO_trainer uses Proximal Policy Optimization to optimise language models.</p>

File diff suppressed because it is too large

View File

@ -27,6 +27,13 @@ description: "Optimise GPT2 to produce IMDB movie reviews with controlled sentim
</div>
{% endraw %}
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>{% include warning.html content='This notebook uses version <code>trl==0.0.3</code>.' %}</p>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<div style="text-align: center">

Binary file not shown (before: 861 KiB, after: 737 KiB)

Binary file not shown (before: 355 KiB, after: 1.1 MiB)

View File

@ -29,11 +29,11 @@ description: "Train transformer language models with reinforcement learning."
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h2 id="What-is-it?">What is it?<a class="anchor-link" href="#What-is-it?"> </a></h2><p>With <code>trl</code> you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the <code>transformer</code> library by 🤗 Hugging Face (<a href="https://github.com/huggingface/transformers">link</a>). Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GTP2 is implemented.</p>
<h2 id="What-is-it?">What is it?<a class="anchor-link" href="#What-is-it?"> </a></h2><p>With <code>trl</code> you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the <code>transformers</code> library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via <code>transformers</code>. At this point only decoder architectures such as GPT2 are implemented.</p>
<p><strong>Highlights:</strong></p>
<ul>
<li>GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.</li>
<li>PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.</li>
<li>GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.</li>
<li>Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.</li>
</ul>
@ -65,15 +65,19 @@ description: "Train transformer language models with reinforcement learning."
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h3 id="Python-package">Python package<a class="anchor-link" href="#Python-package"> </a></h3><p>Install the library with pip:</p>
<p><code>pip install trl</code></p>
<h3 id="Repository">Repository<a class="anchor-link" href="#Repository"> </a></h3><p>If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:</p>
<p><code>git clone https://github.com/lvwerra/trl.git</code></p>
<p><code>cd tlr/</code></p>
<p><code>pip install -r requirements.txt</code></p>
<div class="highlight"><pre><span></span>pip install trl
</pre></div>
<h3 id="From-source">From source<a class="anchor-link" href="#From-source"> </a></h3><p>If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:</p>
<div class="highlight"><pre><span></span>git clone https://github.com/lvwerra/trl.git
<span class="nb">cd</span> trl/
pip install -r requirements.txt
</pre></div>
<h3 id="Jupyter-notebooks">Jupyter notebooks<a class="anchor-link" href="#Jupyter-notebooks"> </a></h3><p>If you run Jupyter notebooks you might need to run the following:</p>
<p><code>jupyter nbextension enable --py --sys-prefix widgetsnbextension</code></p>
<div class="highlight"><pre><span></span>jupyter nbextension <span class="nb">enable</span> --py --sys-prefix widgetsnbextension
</pre></div>
<p>For Jupyterlab additionally this command:</p>
<p><code>jupyter labextension install @jupyter-widgets/jupyterlab-manager</code></p>
<div class="highlight"><pre><span></span>jupyter labextension install @jupyter-widgets/jupyterlab-manager
</pre></div>
</div>
</div>
@ -111,7 +115,7 @@ description: "Train transformer language models with reinforcement learning."
<span class="c1"># initialize trainer</span>
<span class="n">ppo_config</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;forward_batch_size&#39;</span><span class="p">:</span> <span class="mi">1</span><span class="p">}</span>
<span class="n">ppo_trainer</span> <span class="o">=</span> <span class="n">PPOTrainer</span><span class="p">(</span><span class="n">gpt2_model</span><span class="p">,</span> <span class="n">gpt2_model_ref</span><span class="p">,</span> <span class="o">**</span><span class="n">ppo_config</span><span class="p">)</span>
<span class="n">ppo_trainer</span> <span class="o">=</span> <span class="n">PPOTrainer</span><span class="p">(</span><span class="n">gpt2_model</span><span class="p">,</span> <span class="n">gpt2_model_ref</span><span class="p">,</span> <span class="n">gpt2_tokenizer</span><span class="p">,</span> <span class="o">**</span><span class="n">ppo_config</span><span class="p">)</span>
<span class="c1"># encode a query</span>
<span class="n">query_txt</span> <span class="o">=</span> <span class="s2">&quot;This morning I went to the &quot;</span>
@ -123,10 +127,10 @@ description: "Train transformer language models with reinforcement learning."
<span class="c1"># define a reward for response</span>
<span class="c1"># (this could be any reward such as human feedback or output from another model)</span>
<span class="n">reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.0</span><span class="p">])</span>
<span class="n">reward</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)]</span>
<span class="c1"># train model with ppo</span>
<span class="n">train_stats</span> <span class="o">=</span> <span class="n">ppo_trainer</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">query_tensor</span><span class="p">,</span> <span class="n">response_tensor</span><span class="p">,</span> <span class="n">reward</span><span class="p">)</span>
<span class="n">train_stats</span> <span class="o">=</span> <span class="n">ppo_trainer</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="n">query_tensor</span><span class="p">[</span><span class="mi">0</span><span class="p">]],</span> <span class="p">[</span><span class="n">response_tensor</span><span class="p">[</span><span class="mi">0</span><span class="p">]],</span> <span class="n">reward</span><span class="p">)</span>
</pre></div>
</div>
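Because queries, responses, and rewards are now plain Python lists of tensors, a single `step` call can mix samples of different lengths. Below is a hypothetical extension of the snippet above to a batch of two; it assumes the trainer was created with `ppo_config = {'batch_size': 2, 'forward_batch_size': 1}` (with a forward batch size of 1 no padding is needed; larger forward batches additionally require the tokenizer to have a pad token set):

```python
from trl.gpt2 import respond_to_batch

# hypothetical batch of two queries with different lengths
query_txts = ["This morning I went to the ", "The weather today is "]
query_tensors = [gpt2_tokenizer.encode(txt, return_tensors="pt")[0] for txt in query_txts]

# generate one response per query (any generation method works)
response_tensors = [respond_to_batch(gpt2_model, q.unsqueeze(0))[0] for q in query_tensors]

# one scalar reward per query/response pair
rewards = [torch.tensor(1.0), torch.tensor(-1.0)]

train_stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
```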
@ -139,9 +143,9 @@ description: "Train transformer language models with reinforcement learning."
<div class="output_area">
<div class="output_subarea output_stream output_stderr output_text">
<pre>Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: [&#39;h.9.attn.masked_bias&#39;, &#39;h.8.attn.masked_bias&#39;, &#39;lm_head.weight&#39;, &#39;h.4.attn.masked_bias&#39;, &#39;v_head.summary.weight&#39;, &#39;h.10.attn.masked_bias&#39;, &#39;h.0.attn.masked_bias&#39;, &#39;h.5.attn.masked_bias&#39;, &#39;h.7.attn.masked_bias&#39;, &#39;h.6.attn.masked_bias&#39;, &#39;h.1.attn.masked_bias&#39;, &#39;v_head.summary.bias&#39;, &#39;h.11.attn.masked_bias&#39;, &#39;h.2.attn.masked_bias&#39;, &#39;h.3.attn.masked_bias&#39;]
<pre>Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: [&#39;v_head.summary.weight&#39;, &#39;h.8.attn.masked_bias&#39;, &#39;h.0.attn.masked_bias&#39;, &#39;h.9.attn.masked_bias&#39;, &#39;h.1.attn.masked_bias&#39;, &#39;h.6.attn.masked_bias&#39;, &#39;h.5.attn.masked_bias&#39;, &#39;h.3.attn.masked_bias&#39;, &#39;lm_head.weight&#39;, &#39;v_head.summary.bias&#39;, &#39;h.2.attn.masked_bias&#39;, &#39;h.11.attn.masked_bias&#39;, &#39;h.7.attn.masked_bias&#39;, &#39;h.4.attn.masked_bias&#39;, &#39;h.10.attn.masked_bias&#39;]
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: [&#39;h.9.attn.masked_bias&#39;, &#39;h.8.attn.masked_bias&#39;, &#39;lm_head.weight&#39;, &#39;h.4.attn.masked_bias&#39;, &#39;v_head.summary.weight&#39;, &#39;h.10.attn.masked_bias&#39;, &#39;h.0.attn.masked_bias&#39;, &#39;h.5.attn.masked_bias&#39;, &#39;h.7.attn.masked_bias&#39;, &#39;h.6.attn.masked_bias&#39;, &#39;h.1.attn.masked_bias&#39;, &#39;v_head.summary.bias&#39;, &#39;h.11.attn.masked_bias&#39;, &#39;h.2.attn.masked_bias&#39;, &#39;h.3.attn.masked_bias&#39;]
Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: [&#39;v_head.summary.weight&#39;, &#39;h.8.attn.masked_bias&#39;, &#39;h.0.attn.masked_bias&#39;, &#39;h.9.attn.masked_bias&#39;, &#39;h.1.attn.masked_bias&#39;, &#39;h.6.attn.masked_bias&#39;, &#39;h.5.attn.masked_bias&#39;, &#39;h.3.attn.masked_bias&#39;, &#39;lm_head.weight&#39;, &#39;v_head.summary.bias&#39;, &#39;h.2.attn.masked_bias&#39;, &#39;h.11.attn.masked_bias&#39;, &#39;h.7.attn.masked_bias&#39;, &#39;h.4.attn.masked_bias&#39;, &#39;h.10.attn.masked_bias&#39;]
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
</pre>
</div>
@ -155,7 +159,7 @@ You should probably TRAIN this model on a down-stream task to be able to use it
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h3 id="Advanced-example:-IMDB-sentiment">Advanced example: IMDB sentiment<a class="anchor-link" href="#Advanced-example:-IMDB-sentiment"> </a></h3><p>For a detailed example check out the notebook <em>Tune GPT2 to generate positive reviews</em>, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:</p>
<h3 id="Advanced-example:-IMDB-sentiment">Advanced example: IMDB sentiment<a class="anchor-link" href="#Advanced-example:-IMDB-sentiment"> </a></h3><p>For a detailed example check out the notebook <code>04-gpt2-sentiment-ppo-training.ipynb</code>, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:</p>
<div style="text-align: center">
{% include image.html max-width="800" file="/trl/images/table_imdb_preview.png" %}
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
@ -171,8 +175,11 @@ You should probably TRAIN this model on a down-stream task to be able to use it
<li><code>00-core.ipynb</code>: Contains the utility functions used throughout the library and examples.</li>
<li><code>01-gpt2-with-value-head.ipynb</code>: Implementation of a <code>transformer</code> compatible GPT2 model with an additional value head as well as a function to generate sequences.</li>
<li><code>02-ppo.ipynb</code>: Implementation of the PPOTrainer used to train language models.</li>
<li><code>03-bert-imdb-training.ipynb</code>: Training of BERT with <code>simpletransformers</code> to classify sentiment on the IMDB dataset.</li>
<li><code>03-bert-imdb-training.ipynb</code>: Training of DistilBERT to classify sentiment on the IMDB dataset.</li>
<li><code>04-gpt2-sentiment-ppo-training.ipynb</code>: Fine-tune GPT2 with the BERT sentiment classifier to produce positive movie reviews.</li>
</ul>
<p>The following notebook still uses <code>trl==0.0.3</code>:</p>
<ul>
<li><code>05-gpt2-sentiment-control.ipynb</code>: Fine-tune GPT2 with the BERT sentiment classifier to produce movie reviews with controlled sentiment.</li>
</ul>
@ -182,7 +189,7 @@ You should probably TRAIN this model on a down-stream task to be able to use it
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h2 id="References">References<a class="anchor-link" href="#References"> </a></h2><h3 id="Proximal-Policy-Optimisation">Proximal Policy Optimisation<a class="anchor-link" href="#Proximal-Policy-Optimisation"> </a></h3><p>The PPO implementation largely follows the structure introduced in the paper <strong>"Fine-Tuning Language Models from Human Preferences"</strong> by D. Ziegler et al. [<a href="https://arxiv.org/pdf/1909.08593.pdf">paper</a>, <a href="https://github.com/openai/lm-human-preferences">code</a>].</p>
<h3 id="Language-models">Language models<a class="anchor-link" href="#Language-models"> </a></h3><p>The language models utilize the <code>transformer</code> library by 🤗Hugging Face.</p>
<h3 id="Language-models">Language models<a class="anchor-link" href="#Language-models"> </a></h3><p>The language models utilize the <code>transformers</code> library by 🤗 Hugging Face.</p>
</div>
</div>

View File

@ -26,10 +26,29 @@
"# export\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"import collections\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Constants"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"WANDB_PADDING = -1"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -64,7 +83,7 @@
" results = dict()\n",
" for k in stats_dicts[0]:\n",
" stats_list = [torch.flatten(d[k]) for d in stats_dicts]\n",
" results[k] = torch.stack(stats_list)\n",
" results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)\n",
" return results\n",
"\n",
"def add_suffix(input_dict, suffix):\n",
@ -144,7 +163,11 @@
" new_dict[k] = v\n",
" if np.isscalar(new_dict[k]):\n",
" new_dict[k] = float(new_dict[k])\n",
" return new_dict\n"
" return new_dict\n",
"\n",
"def listify_batch(tensor):\n",
" \"\"\"Turns the first dimension of a tensor into a list.\"\"\"\n",
" return [tensor[i] for i in range(tensor.shape[0])]"
]
},
{
@ -197,7 +220,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}

View File

@ -43,10 +43,31 @@
"\n",
"from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel\n",
"from transformers import top_k_top_p_filtering\n",
"from transformers.modeling_outputs import ModelOutput\n",
"from torch import nn\n",
"from torch.nn import Identity\n",
"import torch.nn.functional as F\n",
"import torch"
"import torch\n",
"from dataclasses import dataclass\n",
"from typing import Optional, Tuple"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# exports\n",
"@dataclass\n",
"class CausalLMOutputWithCrossAttentions(ModelOutput):\n",
" loss: Optional[torch.FloatTensor] = None\n",
" logits: torch.FloatTensor = None\n",
" past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n",
" hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n",
" attentions: Optional[Tuple[torch.FloatTensor]] = None\n",
" cross_attentions: Optional[Tuple[torch.FloatTensor]] = None\n",
" value: Optional[torch.FloatTensor] = None"
]
},
{
@ -138,8 +159,11 @@
" mc_token_ids=None,\n",
" lm_labels=None,\n",
" mc_labels=None,\n",
" return_dict=False,\n",
" output_attentions=False,\n",
" output_hidden_states=False,\n",
" ):\n",
" \n",
" loss=None\n",
" transformer_outputs = self.transformer(\n",
" input_ids,\n",
" past_key_values=past_key_values,\n",
@ -155,8 +179,20 @@
" lm_logits = self.lm_head(hidden_states)\n",
" value = self.v_head(hidden_states).squeeze(-1)\n",
"\n",
" outputs = (lm_logits,) + transformer_outputs[1:] + (value,)\n",
" \n",
" if not return_dict:\n",
" outputs = (lm_logits,) + transformer_outputs[1:] + (value,)\n",
" return outputs\n",
"\n",
" return CausalLMOutputWithCrossAttentions(\n",
" loss=loss,\n",
" logits=lm_logits,\n",
" past_key_values=transformer_outputs.past_key_values,\n",
" hidden_states=transformer_outputs.hidden_states,\n",
" attentions=transformer_outputs.attentions,\n",
" cross_attentions=transformer_outputs.cross_attentions,\n",
" value=value,\n",
" ) \n",
" return outputs"
]
},
@ -172,7 +208,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.6.attn.masked_bias', 'h.10.attn.masked_bias', 'h.0.attn.masked_bias', 'h.3.attn.masked_bias', 'h.7.attn.masked_bias', 'h.5.attn.masked_bias', 'h.11.attn.masked_bias', 'h.9.attn.masked_bias', 'h.8.attn.masked_bias', 'lm_head.weight', 'h.4.attn.masked_bias', 'v_head.summary.weight', 'h.2.attn.masked_bias', 'v_head.summary.bias', 'h.1.attn.masked_bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = GPT2HeadWithValueModel.from_pretrained('gpt2')\n",
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2')"
@ -464,7 +509,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}

View File

@ -40,15 +40,18 @@
"import time\n",
"import random\n",
"\n",
"from transformers import DataCollatorForLanguageModeling\n",
"\n",
"from trl.core import (logprobs_from_logits,\n",
" whiten,\n",
" clip_by_value,\n",
" entropy_from_logits,\n",
" flatten_dict,\n",
" average_torch_dicts,\n",
" stats_to_np,\n",
" stack_dicts,\n",
" add_suffix)"
" whiten,\n",
" clip_by_value,\n",
" entropy_from_logits,\n",
" flatten_dict,\n",
" average_torch_dicts,\n",
" stats_to_np,\n",
" stack_dicts,\n",
" add_suffix,\n",
" WANDB_PADDING)"
]
},
{
@ -132,13 +135,14 @@
" \"ppo_epochs\": 4, \n",
" } \n",
" \n",
" def __init__(self, model, ref_model, **ppo_params):\n",
" def __init__(self, model, ref_model, tokenizer, **ppo_params):\n",
" \"\"\"\n",
" Initialize PPOTrainer.\n",
" \n",
" Args:\n",
" model (torch.model): Hugging Face transformer GPT2 model with value head\n",
" ref_model (torch.model): Hugging Face transformer GPT2 refrence model used for KL penalty\n",
" tokenizer (tokenizer): Hugging Face tokenizer\n",
" ppo_params (dict or None): PPO parameters for training. Can include following keys:\n",
" 'lr' (float): Adam learning rate, default: 1.41e-5\n",
" 'batch_size' (int): Number of samples per optimisation step, default: 256\n",
@ -160,6 +164,9 @@
" \n",
" self.ref_model = ref_model\n",
" self.model = model\n",
" self.tokenizer = tokenizer\n",
" self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
" \n",
" self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr'])\n",
" \n",
" if self.ppo_params['adap_kl_ctrl']:\n",
@ -170,32 +177,33 @@
" self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef'])\n",
"\n",
"\n",
" def step(self, query, response, scores):\n",
" def step(self, queries, responses, scores):\n",
" \"\"\"\n",
" Run a PPO optimisation step.\n",
" \n",
" args:\n",
" query (torch.tensor): tensor containing the encoded queries, shape [batch_size, query_length]\n",
" response (torch.tensor): tensor containing the encoded responses, shape [batch_size, response_length]\n",
" scores (torch.tensor): tensor containing the scores, shape [batch_size]\n",
" queries (List): List of tensors containing the encoded queries, shape [query_length]\n",
" responses (List): List of tensors containing the encoded responses, shape [response_length]\n",
" scores (List): tensor containing the scores, shape [batch_size]\n",
" \n",
" returns:\n",
" train_stats (dict): a summary of the training statistics\n",
" \"\"\"\n",
"\n",
" bs = self.ppo_params['batch_size']\n",
" assert bs == len(queries), f\"Batch size ({bs}) does not match number of examples ({len(queries)})\"\n",
" \n",
" timing = dict()\n",
" t0 = time.time()\n",
" \n",
" gen_len = response.shape[1]\n",
" model_input = torch.cat((query, response), axis=1)\n",
" response_lengths = [len(r) for r in responses]\n",
" \n",
" t = time.time()\n",
" logprobs, ref_logprobs, values = self.batched_forward_pass(model_input, gen_len)\n",
" logprobs, ref_logprobs, values = self.batched_forward_pass(queries, responses)\n",
" timing['time/ppo/forward_pass'] = time.time()-t\n",
"\n",
" t = time.time()\n",
" rewards, non_score_reward, kl_coef = self.compute_rewards(scores, logprobs, ref_logprobs)\n",
" rewards, non_score_reward = self.compute_rewards(scores, logprobs, ref_logprobs)\n",
" timing['time/ppo/compute_rewards'] = time.time()-t \n",
" \n",
" t = time.time() \n",
@ -205,9 +213,10 @@
" random.shuffle(idxs)\n",
" for i in range(bs):\n",
" idx = idxs[i]\n",
" train_stats = self.train_minibatch(logprobs[idx:idx+1], values[idx:idx+1],\n",
" rewards[idx:idx+1], query[idx:idx+1],\n",
" response[idx:idx+1], model_input[idx:idx+1])\n",
" train_stats = self.train_minibatch(logprobs[idx].unsqueeze(0), values[idx].unsqueeze(0),\n",
" rewards[idx].unsqueeze(0), queries[idx].unsqueeze(0),\n",
" responses[idx].unsqueeze(0),\n",
" torch.cat([queries[idx],responses[idx]]).unsqueeze(0))\n",
" all_stats.append(train_stats)\n",
" timing['time/ppo/optimize_step'] = time.time()-t\n",
" \n",
@ -216,11 +225,12 @@
" \n",
" # reshape advantages/ratios such that they are not averaged.\n",
" train_stats['policy/advantages'] = torch.flatten(train_stats['policy/advantages']).unsqueeze(0)\n",
" train_stats['policy/advantages'] = torch.nan_to_num(train_stats['policy/advantages'], WANDB_PADDING)\n",
" train_stats['policy/ratio'] = torch.flatten(train_stats['policy/ratio']).unsqueeze(0)\n",
" \n",
" stats = self.record_step_stats(scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs,\n",
" non_score_reward=non_score_reward, train_stats=train_stats,\n",
" kl_coef=kl_coef)\n",
" kl_coef=self.kl_ctl.value)\n",
" stats = stats_to_np(stats)\n",
" timing['time/ppo/calc_stats'] = time.time()-t\n",
"\n",
@ -230,24 +240,30 @@
" stats.update(timing)\n",
" return stats\n",
"\n",
" def batched_forward_pass(self, model_input, gen_len):\n",
" def batched_forward_pass(self, queries, responses):\n",
" \"\"\"Calculate model outputs in multiple batches.\"\"\"\n",
" bs = self.ppo_params['batch_size']\n",
" fbs = self.ppo_params['forward_batch_size']\n",
" logprobs = []\n",
" ref_logprobs = []\n",
" values = []\n",
" all_logprobs = []\n",
" all_ref_logprobs = []\n",
" all_values = []\n",
" \n",
" for i in range(int(self.ppo_params['batch_size']/fbs)):\n",
" m_input = model_input[i*fbs:(i+1)*fbs]\n",
" logits, _, v = self.model(m_input)\n",
" ref_logits, _, _ = self.ref_model(m_input)\n",
" \n",
" values.append(v[:, -gen_len-1:-1].detach())\n",
" logprobs.append(logprobs_from_logits(logits[:,:-1,:], m_input[:,1:])[:, -gen_len:].detach())\n",
" ref_logprobs.append(logprobs_from_logits(ref_logits[:,:-1,:], m_input[:,1:])[:, -gen_len:].detach())\n",
" \n",
" return torch.cat(logprobs), torch.cat(ref_logprobs), torch.cat(values)\n",
" for i in range(int(bs/fbs)):\n",
" query_batch = queries[i*fbs:(i+1)*fbs]\n",
" response_batch = responses[i*fbs:(i+1)*fbs]\n",
" input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])[\"input_ids\"]\n",
" with torch.no_grad():\n",
" logits, _, v = self.model(input_ids)\n",
" ref_logits, _, _ = self.ref_model(input_ids)\n",
" logprobs = logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:])\n",
" ref_logprobs = logprobs_from_logits(ref_logits[:,:-1,:], input_ids[:,1:])\n",
" for j in range(fbs):\n",
" start = len(query_batch[j])-1\n",
" end = len(query_batch[j]) + len(response_batch[j])-1\n",
" all_values.append(v[j, start-1:end-1])\n",
" all_logprobs.append(logprobs[j, start:end])\n",
" all_ref_logprobs.append(ref_logprobs[j, start:end])\n",
" return all_logprobs, all_ref_logprobs, all_values\n",
" \n",
" def train_minibatch(self, logprobs, values, rewards, query, response, model_input):\n",
" \"\"\"Train one PPO minibatch\"\"\"\n",
@ -260,18 +276,22 @@
" \n",
" def compute_rewards(self, scores, logprobs, ref_logprobs):\n",
" \"\"\"Compute per token rewards from scores and KL-penalty.\"\"\"\n",
" kl = logprobs - ref_logprobs\n",
" non_score_reward = -self.kl_ctl.value * kl\n",
" rewards = non_score_reward.clone().detach()\n",
" rewards[:, -1] += scores\n",
" return rewards, non_score_reward, self.kl_ctl.value\n",
" rewards, non_score_rewards = [], []\n",
" for score, logprob, ref_logprob in zip(scores, logprobs, ref_logprobs):\n",
" kl = logprob - ref_logprob\n",
" non_score_reward = -self.kl_ctl.value * kl\n",
" non_score_rewards.append(non_score_reward)\n",
" reward = non_score_reward.clone()\n",
" reward[-1] += score\n",
" rewards.append(reward)\n",
" return rewards, non_score_rewards\n",
"\n",
" def loss(self, old_logprobs, values, rewards, query, response, model_input):\n",
" \"\"\"Calculate policy and value losses.\"\"\"\n",
" lastgaelam = 0\n",
" advantages_reversed = []\n",
" gen_len = response.shape[1]\n",
"\n",
" \n",
" for t in reversed(range(gen_len)):\n",
" nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0\n",
" delta = rewards[:, t] + self.ppo_params['gamma'] * nextvalues - values[:, t]\n",
@ -329,13 +349,13 @@
"\n",
" def record_step_stats(self, kl_coef, **data):\n",
" \"\"\"Record training step statistics.\"\"\"\n",
" kl = data['logprobs'] - data['ref_logprobs']\n",
" mean_kl = torch.mean(torch.sum(kl, axis=-1))\n",
" mean_entropy = torch.mean(torch.sum(-data['logprobs'], axis=1))\n",
" mean_non_score_reward =torch.mean(torch.sum(data['non_score_reward'], axis=1))\n",
" kl_list = [logprobs-ref_logprobs for logprobs, ref_logprobs in zip(data['logprobs'], data['ref_logprobs'])] \n",
" mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list]))\n",
" mean_entropy = torch.mean(torch.stack([torch.sum(-log_probs) for log_probs in data['logprobs']]))\n",
" mean_non_score_reward =torch.mean(torch.stack([torch.sum(non_score_reward) for non_score_reward in data['non_score_reward']]))\n",
" stats = {\n",
" 'objective/kl': mean_kl,\n",
" 'objective/kl_dist': kl,\n",
" 'objective/kl_dist': kl_list,\n",
" 'objective/logprobs': data['logprobs'],\n",
" 'objective/ref_logprobs': data['ref_logprobs'],\n",
" 'objective/kl_coef': kl_coef,\n",
@ -396,7 +416,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}

File diff suppressed because it is too large

View File

@ -8,6 +8,13 @@
"> Optimise GPT2 to produce IMDB movie reviews with controlled sentiment using a BERT sentiment classifier for rewards."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> warning: This notebook uses version `trl==0.0.3`."
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -1644,18 +1651,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,

Binary file not shown (before: 861 KiB, after: 737 KiB)

Binary file not shown (before: 355 KiB, after: 1.1 MiB)

View File

@ -14,11 +14,11 @@
"metadata": {},
"source": [
"## What is it?\n",
"With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built with the `transformer` library by 🤗 Hugging Face ([link](https://github.com/huggingface/transformers)). Therefore, pre-trained language models can be directly loaded via the transformer interface. At this point only GTP2 is implemented.\n",
"With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the [`transformer`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point only decoder architectures such as GTP2 are implemented.\n",
"\n",
"**Highlights:**\n",
"- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.\n",
"- PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.\n",
"- GPT2 model with a value head: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.\n",
"- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier."
]
},
@ -55,27 +55,29 @@
"source": [
"### Python package\n",
"Install the library with pip:\n",
"```bash\n",
"pip install trl\n",
"```\n",
"\n",
"`pip install trl`\n",
"\n",
"### Repository\n",
"### From source\n",
"If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:\n",
"\n",
"`git clone https://github.com/lvwerra/trl.git`\n",
"\n",
"`cd tlr/`\n",
"\n",
"`pip install -r requirements.txt`\n",
"\n",
"```bash\n",
"git clone https://github.com/lvwerra/trl.git\n",
"cd tlr/\n",
"pip install -r requirements.txt\n",
"```\n",
"### Jupyter notebooks\n",
"\n",
"If you run Jupyter notebooks you might need to run the following:\n",
"\n",
"`jupyter nbextension enable --py --sys-prefix widgetsnbextension`\n",
"```bash\n",
"jupyter nbextension enable --py --sys-prefix widgetsnbextension\n",
"```\n",
"\n",
"For Jupyterlab additionally this command:\n",
"\n",
"`jupyter labextension install @jupyter-widgets/jupyterlab-manager`"
"```bash\n",
"jupyter labextension install @jupyter-widgets/jupyterlab-manager\n",
"```"
]
},
{
@ -112,7 +114,7 @@
"\n",
"# initialize trainer\n",
"ppo_config = {'batch_size': 1, 'forward_batch_size': 1}\n",
"ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, **ppo_config)\n",
"ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)\n",
"\n",
"# encode a query\n",
"query_txt = \"This morning I went to the \"\n",
@ -124,10 +126,10 @@
"\n",
"# define a reward for response\n",
"# (this could be any reward such as human feedback or output from another model)\n",
"reward = torch.tensor([1.0]) \n",
"reward = [torch.tensor(1.0)]\n",
"\n",
"# train model with ppo\n",
"train_stats = ppo_trainer.step(query_tensor, response_tensor, reward)"
"train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)"
]
},
{
@ -135,7 +137,7 @@
"metadata": {},
"source": [
"### Advanced example: IMDB sentiment\n",
"For a detailed example check out the notebook *Tune GPT2 to generate positive reviews*, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:\n",
"For a detailed example check out the notebook `04-gpt2-sentiment-ppo-training.ipynb`, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:\n",
"\n",
"<div style=\"text-align: center\">\n",
"<img src='images/table_imdb_preview.png' width='800'>\n",
@ -154,8 +156,10 @@
"- `00-core.ipynb`: Contains the utility functions used throughout the library and examples.\n",
"- `01-gpt2-with-value-head.ipynb`: Implementation of a `transformer` compatible GPT2 model with an additional value head as well as a function to generate sequences.\n",
"- `02-ppo.ipynb`: Implementation of the PPOTrainer used to train language models.\n",
"- `03-bert-imdb-training.ipynb`: Training of BERT with `simpletransformers` to classify sentiment on the IMDB dataset.\n",
"- `03-bert-imdb-training.ipynb`: Training of DistilBERT to classify sentiment on the IMDB dataset.\n",
"- `04-gpt2-sentiment-ppo-training.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce positive movie reviews.\n",
"\n",
"Currently using `trl==0.0.3`:\n",
"- `05-gpt2-sentiment-control.ipynb`: Fine-tune GPT2 with the BERT sentiment classifier to produce movie reviews with controlled sentiment."
]
},
@ -169,7 +173,7 @@
"The PPO implementation largely follows the structure introduced in the paper **\"Fine-Tuning Language Models from Human Preferences\"** by D. Ziegler et al. \\[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].\n",
"\n",
"### Language models\n",
"The language models utilize the `transformer` library by 🤗Hugging Face."
"The language models utilize the `transformers` library by 🤗 Hugging Face."
]
},
{
@ -182,7 +186,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}

View File

@ -4,7 +4,7 @@ jupyterlab==2.2.10
nbdev==0.2.16
datasets==1.17.0
torch>=1.4.0
tqdm==4.43.0
tqdm
transformers==4.15.0
wandb==0.10.20
matplotlib==3.5.1

View File

@ -2,7 +2,8 @@
__all__ = ["index", "modules", "custom_doc_links", "git_url"]
index = {"flatten_dict": "00-core.ipynb",
index = {"WANDB_PADDING": "00-core.ipynb",
"flatten_dict": "00-core.ipynb",
"stack_dicts": "00-core.ipynb",
"add_suffix": "00-core.ipynb",
"pad_to_size": "00-core.ipynb",
@ -12,7 +13,9 @@ index = {"flatten_dict": "00-core.ipynb",
"entropy_from_logits": "00-core.ipynb",
"average_torch_dicts": "00-core.ipynb",
"stats_to_np": "00-core.ipynb",
"listify_batch": "00-core.ipynb",
"build_bert_batch_from_txt": "00-core.ipynb",
"CausalLMOutputWithCrossAttentions": "01-gpt2-with-value-head.ipynb",
"ValueHead": "01-gpt2-with-value-head.ipynb",
"GPT2HeadWithValueModel": "01-gpt2-with-value-head.ipynb",
"respond_to_batch": "01-gpt2-with-value-head.ipynb",

View File

@ -1,14 +1,20 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00-core.ipynb (unless otherwise specified).
__all__ = ['flatten_dict', 'stack_dicts', 'add_suffix', 'pad_to_size', 'logprobs_from_logits', 'whiten',
'clip_by_value', 'entropy_from_logits', 'average_torch_dicts', 'stats_to_np', 'build_bert_batch_from_txt']
__all__ = ['WANDB_PADDING', 'flatten_dict', 'stack_dicts', 'add_suffix', 'pad_to_size', 'logprobs_from_logits',
'whiten', 'clip_by_value', 'entropy_from_logits', 'average_torch_dicts', 'stats_to_np', 'listify_batch',
'build_bert_batch_from_txt']
# Cell
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import collections
import numpy as np
# Cell
WANDB_PADDING = -1
# Cell
def flatten_dict(nested, sep='/'):
@ -30,7 +36,7 @@ def stack_dicts(stats_dicts):
results = dict()
for k in stats_dicts[0]:
stats_list = [torch.flatten(d[k]) for d in stats_dicts]
results[k] = torch.stack(stats_list)
results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
return results
def add_suffix(input_dict, suffix):
@ -98,6 +104,9 @@ def stats_to_np(stats_dict):
new_dict[k] = float(new_dict[k])
return new_dict
def listify_batch(tensor):
"""Turns the first dimension of a tensor into a list."""
return [tensor[i] for i in range(tensor.shape[0])]
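# Usage sketch (illustration, not part of the diff; hypothetical tensors): stack_dicts now
# pads ragged stats with WANDB_PADDING so wandb can log them, and listify_batch turns a
# batched tensor into the list-of-tensors format used by the new PPOTrainer API.
example_stats = [torch.tensor([0.1, 0.2]), torch.tensor([0.3, 0.4, 0.5])]
padded = pad_sequence(example_stats, batch_first=True, padding_value=WANDB_PADDING)
# padded -> tensor([[ 0.1,  0.2, -1.0],
#                   [ 0.3,  0.4,  0.5]])
example_batch = torch.arange(6).reshape(2, 3)    # e.g. the output of a batched generate call
listify_batch(example_batch)                     # -> [tensor([0, 1, 2]), tensor([3, 4, 5])]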
# Cell

View File

@ -1,15 +1,29 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01-gpt2-with-value-head.ipynb (unless otherwise specified).
__all__ = ['ValueHead', 'GPT2HeadWithValueModel', 'respond_to_batch']
__all__ = ['CausalLMOutputWithCrossAttentions', 'ValueHead', 'GPT2HeadWithValueModel', 'respond_to_batch']
# Cell
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel
from transformers import top_k_top_p_filtering
from transformers.modeling_outputs import ModelOutput
from torch import nn
from torch.nn import Identity
import torch.nn.functional as F
import torch
from dataclasses import dataclass
from typing import Optional, Tuple
# Cell
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
value: Optional[torch.FloatTensor] = None
# Cell
@ -87,8 +101,11 @@ class GPT2HeadWithValueModel(GPT2PreTrainedModel):
mc_token_ids=None,
lm_labels=None,
mc_labels=None,
return_dict=False,
output_attentions=False,
output_hidden_states=False,
):
loss=None
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
@ -104,8 +121,20 @@ class GPT2HeadWithValueModel(GPT2PreTrainedModel):
lm_logits = self.lm_head(hidden_states)
value = self.v_head(hidden_states).squeeze(-1)
outputs = (lm_logits,) + transformer_outputs[1:] + (value,)
if not return_dict:
outputs = (lm_logits,) + transformer_outputs[1:] + (value,)
return outputs
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
value=value,
)
return outputs
# Cell
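
For orientation, a hedged sketch of how the two output modes of the patched forward pass can be consumed; the checkpoint name and token ids are placeholders, and the three-element tuple layout assumes the default GPT-2 config with cached past key values, matching how the PPO trainer unpacks it below.

```python
import torch
from trl.gpt2 import GPT2HeadWithValueModel  # import path as used elsewhere in this repo

model = GPT2HeadWithValueModel.from_pretrained('gpt2')
input_ids = torch.tensor([[31373, 995]])  # hypothetical token ids, shape [1, 2]

# Tuple mode (return_dict=False, the default) keeps the old (logits, past_key_values, value) layout.
logits, _, value = model(input_ids)

# Dataclass mode exposes the same quantities plus the usual transformers fields.
out = model(input_ids, return_dict=True)
print(out.logits.shape, out.value.shape)  # torch.Size([1, 2, 50257]) torch.Size([1, 2])
```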

View File

@ -11,15 +11,18 @@ import collections
import time
import random
from transformers import DataCollatorForLanguageModeling
from .core import (logprobs_from_logits,
whiten,
clip_by_value,
entropy_from_logits,
flatten_dict,
average_torch_dicts,
stats_to_np,
stack_dicts,
add_suffix)
whiten,
clip_by_value,
entropy_from_logits,
flatten_dict,
average_torch_dicts,
stats_to_np,
stack_dicts,
add_suffix,
WANDB_PADDING)
# Cell
@ -72,13 +75,14 @@ class PPOTrainer:
"ppo_epochs": 4,
}
def __init__(self, model, ref_model, **ppo_params):
def __init__(self, model, ref_model, tokenizer, **ppo_params):
"""
Initialize PPOTrainer.
Args:
model (torch.model): Hugging Face transformer GPT2 model with value head
ref_model (torch.model): Hugging Face transformer GPT2 reference model used for KL penalty
tokenizer (tokenizer): Hugging Face tokenizer
ppo_params (dict or None): PPO parameters for training. Can include following keys:
'lr' (float): Adam learning rate, default: 1.41e-5
'batch_size' (int): Number of samples per optimisation step, default: 256
@ -100,6 +104,9 @@ class PPOTrainer:
self.ref_model = ref_model
self.model = model
self.tokenizer = tokenizer
self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr'])
if self.ppo_params['adap_kl_ctrl']:
@ -110,32 +117,33 @@ class PPOTrainer:
self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef'])
def step(self, query, response, scores):
def step(self, queries, responses, scores):
"""
Run a PPO optimisation step.
args:
query (torch.tensor): tensor containing the encoded queries, shape [batch_size, query_length]
response (torch.tensor): tensor containing the encoded responses, shape [batch_size, response_length]
scores (torch.tensor): tensor containing the scores, shape [batch_size]
queries (List): List of tensors containing the encoded queries, shape [query_length]
responses (List): List of tensors containing the encoded responses, shape [response_length]
scores (List): List of scalar tensors containing the scores, one per query-response pair
returns:
train_stats (dict): a summary of the training statistics
"""
bs = self.ppo_params['batch_size']
assert bs == len(queries), f"Batch size ({bs}) does not match number of examples ({len(queries)})"
timing = dict()
t0 = time.time()
gen_len = response.shape[1]
model_input = torch.cat((query, response), axis=1)
response_lengths = [len(r) for r in responses]
t = time.time()
logprobs, ref_logprobs, values = self.batched_forward_pass(model_input, gen_len)
logprobs, ref_logprobs, values = self.batched_forward_pass(queries, responses)
timing['time/ppo/forward_pass'] = time.time()-t
t = time.time()
rewards, non_score_reward, kl_coef = self.compute_rewards(scores, logprobs, ref_logprobs)
rewards, non_score_reward = self.compute_rewards(scores, logprobs, ref_logprobs)
timing['time/ppo/compute_rewards'] = time.time()-t
t = time.time()
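
Purely as an illustrative sketch of the caller's side of this new contract (the trainer instance and all tensors below are hypothetical): queries and responses are now Python lists of 1-D tensors that may all have different lengths, scores is a list of scalar tensors, and the list length must equal the configured `batch_size`.

```python
import torch

# Hypothetical encoded queries/responses of different lengths (batch_size = 3).
queries = [torch.tensor([11, 12, 13]), torch.tensor([21, 22]), torch.tensor([31, 32, 33, 34])]
responses = [torch.tensor([41, 42]), torch.tensor([51, 52, 53]), torch.tensor([61])]
scores = [torch.tensor(0.8), torch.tensor(-0.3), torch.tensor(1.2)]  # one scalar reward per pair

# ppo_trainer is assumed to be a PPOTrainer built with (model, ref_model, tokenizer, batch_size=3, ...).
train_stats = ppo_trainer.step(queries, responses, scores)
```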
@ -145,9 +153,10 @@ class PPOTrainer:
random.shuffle(idxs)
for i in range(bs):
idx = idxs[i]
train_stats = self.train_minibatch(logprobs[idx:idx+1], values[idx:idx+1],
rewards[idx:idx+1], query[idx:idx+1],
response[idx:idx+1], model_input[idx:idx+1])
train_stats = self.train_minibatch(logprobs[idx].unsqueeze(0), values[idx].unsqueeze(0),
rewards[idx].unsqueeze(0), queries[idx].unsqueeze(0),
responses[idx].unsqueeze(0),
torch.cat([queries[idx],responses[idx]]).unsqueeze(0))
all_stats.append(train_stats)
timing['time/ppo/optimize_step'] = time.time()-t
@ -156,11 +165,12 @@ class PPOTrainer:
# reshape advantages/ratios such that they are not averaged.
train_stats['policy/advantages'] = torch.flatten(train_stats['policy/advantages']).unsqueeze(0)
train_stats['policy/advantages'] = torch.nan_to_num(train_stats['policy/advantages'], WANDB_PADDING)
train_stats['policy/ratio'] = torch.flatten(train_stats['policy/ratio']).unsqueeze(0)
stats = self.record_step_stats(scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs,
non_score_reward=non_score_reward, train_stats=train_stats,
kl_coef=kl_coef)
kl_coef=self.kl_ctl.value)
stats = stats_to_np(stats)
timing['time/ppo/calc_stats'] = time.time()-t
@ -170,24 +180,30 @@ class PPOTrainer:
stats.update(timing)
return stats
def batched_forward_pass(self, model_input, gen_len):
def batched_forward_pass(self, queries, responses):
"""Calculate model outputs in multiple batches."""
bs = self.ppo_params['batch_size']
fbs = self.ppo_params['forward_batch_size']
logprobs = []
ref_logprobs = []
values = []
all_logprobs = []
all_ref_logprobs = []
all_values = []
for i in range(int(self.ppo_params['batch_size']/fbs)):
m_input = model_input[i*fbs:(i+1)*fbs]
logits, _, v = self.model(m_input)
ref_logits, _, _ = self.ref_model(m_input)
values.append(v[:, -gen_len-1:-1].detach())
logprobs.append(logprobs_from_logits(logits[:,:-1,:], m_input[:,1:])[:, -gen_len:].detach())
ref_logprobs.append(logprobs_from_logits(ref_logits[:,:-1,:], m_input[:,1:])[:, -gen_len:].detach())
return torch.cat(logprobs), torch.cat(ref_logprobs), torch.cat(values)
for i in range(int(bs/fbs)):
query_batch = queries[i*fbs:(i+1)*fbs]
response_batch = responses[i*fbs:(i+1)*fbs]
input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])["input_ids"]
with torch.no_grad():
logits, _, v = self.model(input_ids)
ref_logits, _, _ = self.ref_model(input_ids)
logprobs = logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:])
ref_logprobs = logprobs_from_logits(ref_logits[:,:-1,:], input_ids[:,1:])
for j in range(fbs):
start = len(query_batch[j])-1
end = len(query_batch[j]) + len(response_batch[j])-1
all_values.append(v[j, start-1:end-1])
all_logprobs.append(logprobs[j, start:end])
all_ref_logprobs.append(ref_logprobs[j, start:end])
return all_logprobs, all_ref_logprobs, all_values
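
As a quick sanity check on the slicing above, the arithmetic for one hypothetical example: `logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:])` places the log-probability of token k+1 at index k, so the chosen start/end indices pick out exactly one log-probability (and one value) per response token.

```python
# Hypothetical lengths for a single example j in the forward batch.
query_len, response_len = 5, 3                  # len(query_batch[j]), len(response_batch[j])

start = query_len - 1                           # index of the first response token's logprob
end = query_len + response_len - 1              # one past the last response token's logprob

assert end - start == response_len              # logprobs[j, start:end] -> one entry per response token
assert (end - 1) - (start - 1) == response_len  # v[j, start-1:end-1] has the matching length
```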
def train_minibatch(self, logprobs, values, rewards, query, response, model_input):
"""Train one PPO minibatch"""
@ -200,11 +216,15 @@ class PPOTrainer:
def compute_rewards(self, scores, logprobs, ref_logprobs):
"""Compute per token rewards from scores and KL-penalty."""
kl = logprobs - ref_logprobs
non_score_reward = -self.kl_ctl.value * kl
rewards = non_score_reward.clone().detach()
rewards[:, -1] += scores
return rewards, non_score_reward, self.kl_ctl.value
rewards, non_score_rewards = [], []
for score, logprob, ref_logprob in zip(scores, logprobs, ref_logprobs):
kl = logprob - ref_logprob
non_score_reward = -self.kl_ctl.value * kl
non_score_rewards.append(non_score_reward)
reward = non_score_reward.clone()
reward[-1] += score
rewards.append(reward)
return rewards, non_score_rewards
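
A tiny worked example of the per-token reward construction above, with made-up numbers and a KL coefficient of 0.2; only the last response token receives the scalar score on top of the KL penalty.

```python
import torch

# Hypothetical per-token log-probabilities for a 3-token response.
logprob = torch.tensor([-1.0, -2.0, -1.5])
ref_logprob = torch.tensor([-1.2, -1.8, -1.5])
score, kl_coef = 1.0, 0.2

kl = logprob - ref_logprob            # tensor([ 0.2000, -0.2000,  0.0000])
non_score_reward = -kl_coef * kl      # tensor([-0.0400,  0.0400, -0.0000])
reward = non_score_reward.clone()
reward[-1] += score                   # tensor([-0.0400,  0.0400,  1.0000])
```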
def loss(self, old_logprobs, values, rewards, query, response, model_input):
"""Calculate policy and value losses."""
@ -269,13 +289,13 @@ class PPOTrainer:
def record_step_stats(self, kl_coef, **data):
"""Record training step statistics."""
kl = data['logprobs'] - data['ref_logprobs']
mean_kl = torch.mean(torch.sum(kl, axis=-1))
mean_entropy = torch.mean(torch.sum(-data['logprobs'], axis=1))
mean_non_score_reward =torch.mean(torch.sum(data['non_score_reward'], axis=1))
kl_list = [logprobs-ref_logprobs for logprobs, ref_logprobs in zip(data['logprobs'], data['ref_logprobs'])]
mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list]))
mean_entropy = torch.mean(torch.stack([torch.sum(-log_probs) for log_probs in data['logprobs']]))
mean_non_score_reward = torch.mean(torch.stack([torch.sum(non_score_reward) for non_score_reward in data['non_score_reward']]))
stats = {
'objective/kl': mean_kl,
'objective/kl_dist': kl,
'objective/kl_dist': kl_list,
'objective/logprobs': data['logprobs'],
'objective/ref_logprobs': data['ref_logprobs'],
'objective/kl_coef': kl_coef,